{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.stats as stats\n",
    "from pathlib import Path\n",
    "import glob\n",
    "import pickle\n",
    "\n",
    "import random\n",
    "import os\n",
    "\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
    "from tqdm import tqdm\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras.layers as L\n",
    "import tensorflow.keras.models as M\n",
    "import tensorflow.keras.backend as K\n",
    "import tensorflow_addons as tfa\n",
    "from tensorflow_addons.layers import WeightNormalization\n",
    "from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping\n",
    "pd.options.mode.chained_assignment = None\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n"
     ]
    }
   ],
   "source": [
    "def get_available_gpus():\n",
    "    \"\"\"Return the names of the GPUs visible to TensorFlow ([] on a CPU-only host).\n",
    "\n",
    "    Uses the public tf.config API instead of the private tensorflow.python.client.device_lib;\n",
    "    device_lib.list_local_devices() builds every local device just to list it (and may grab\n",
    "    GPU memory). Names are reported as '/physical_device:GPU:N' rather than '/device:GPU:N'.\n",
    "    \"\"\"\n",
    "    return [device.name for device in tf.config.list_physical_devices('GPU')]\n",
    "\n",
    "\n",
    "print(get_available_gpus())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>site_path_timestamp</th>\n",
       "      <th>site</th>\n",
       "      <th>path</th>\n",
       "      <th>ts_waypoint</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>9017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>15326</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>18763</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5a0546857ecc773753327266_046cfa46be49fc1083481...</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>046cfa46be49fc10834815c6</td>\n",
       "      <td>22328</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 site_path_timestamp  \\\n",
       "0  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "1  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "2  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "3  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "4  5a0546857ecc773753327266_046cfa46be49fc1083481...   \n",
       "\n",
       "                       site                      path  ts_waypoint  \n",
       "0  5a0546857ecc773753327266  046cfa46be49fc10834815c6            9  \n",
       "1  5a0546857ecc773753327266  046cfa46be49fc10834815c6         9017  \n",
       "2  5a0546857ecc773753327266  046cfa46be49fc10834815c6        15326  \n",
       "3  5a0546857ecc773753327266  046cfa46be49fc10834815c6        18763  \n",
       "4  5a0546857ecc773753327266  046cfa46be49fc10834815c6        22328  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# site_path_timestamp is '<site>_<path>_<timestamp>'; split it once into its three parts.\n",
    "sample_submission = pd.read_csv('../input/indoor-location-navigation/sample_submission.csv')\n",
    "id_parts = sample_submission['site_path_timestamp'].str.split('_', expand=True)\n",
    "sample_submission['site'] = id_parts[0]\n",
    "sample_submission['path'] = id_parts[1]\n",
    "sample_submission['ts_waypoint'] = id_parts[2].astype('int64')\n",
    "# Only the identifiers are used here, so drop the floor/x/y columns.\n",
    "sample_submission = sample_submission.drop(columns=['floor', 'x', 'y'])\n",
    "\n",
    "# Lookup used later to attach a site to each test path.\n",
    "path2site = dict(zip(sample_submission.path, sample_submission.site))\n",
    "sample_submission.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# glob order depends on the filesystem; sorting keeps file order (and everything derived\n",
    "# from it, e.g. the ssid/bssid encodings) reproducible across machines.\n",
    "test_wifi_files = sorted(glob.glob('../input/wifi_lbl_encode/test/*.txt'))\n",
    "\n",
    "# Option A: sorted(glob.glob('../input/indoor-navigation-and-location-wifi-features-alldata/*train.csv'))\n",
    "train_files = sorted(glob.glob('../input/data_abstract/*_train_waypoint_all.csv'))  # option B\n",
    "\n",
    "train_wifi_files = sorted(glob.glob('../input/wifi_lbl_encode/train/*/*/*.txt'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../input/data_abstract/5a0546857ecc773753327266_train_waypoint_all.csv',\n",
       " '../input/data_abstract/5c3c44b80379370013e0fd2b_train_waypoint_all.csv']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first two training waypoint CSV paths\n",
    "train_files[0:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len train site list: 24\n"
     ]
    }
   ],
   "source": [
    "# Keep only the training wifi files whose site appears in sample_submission.\n",
    "train_site_list = list(sample_submission.site.unique())\n",
    "submission_sites = set(train_site_list)  # set for O(1) membership tests\n",
    "train_wifi_files = [\n",
    "    wifi_file for wifi_file in train_wifi_files\n",
    "    if wifi_file.split('/')[-3] in submission_sites  # layout: .../<site>/<floor>/<path>.txt\n",
    "]\n",
    "print('len train site list:', len(train_site_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10877"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of training wifi files left after the site filter\n",
    "len(train_wifi_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11503/11503 [01:02<00:00, 184.06it/s]\n"
     ]
    }
   ],
   "source": [
    "# Collect every raw ssid/bssid seen in train + test so the encodings below cover both.\n",
    "ssidlist = set()\n",
    "bssidlist = set()\n",
    "for filename in tqdm(train_wifi_files + test_wifi_files):\n",
    "    ids = pd.read_csv(filename, usecols=['ssid', 'bssid'])  # only these columns are needed\n",
    "    # update() grows the sets in place; `a = a | b` copied the whole set for every file.\n",
    "    ssidlist.update(ids['ssid'])\n",
    "    bssidlist.update(ids['bssid'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20044, 65952)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary sizes before the 'empty' padding token is added (both are already sets).\n",
    "len(ssidlist), len(bssidlist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "seqlen = 100  # readings kept per timestamp: longer groups are truncated, shorter ones padded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each raw id to a dense index; 'empty' (the padding token) gets the last index.\n",
    "# Sorting fixes the order so the encoding no longer depends on set iteration order;\n",
    "# key=str keeps the sort total even if a column mixes value types.\n",
    "ssiddict = {raw: idx for idx, raw in enumerate(sorted(ssidlist, key=str) + ['empty'])}\n",
    "bssiddict = {raw: idx for idx, raw in enumerate(sorted(bssidlist, key=str) + ['empty'])}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10877/10877 [00:42<00:00, 254.02it/s]\n"
     ]
    }
   ],
   "source": [
    "train_frames = []\n",
    "for filename in tqdm(train_wifi_files):\n",
    "    parts = filename.split('/')  # .../<site>/<floor>/<path>.txt\n",
    "    frame = pd.read_csv(filename)\n",
    "    frame['path'] = parts[-1].replace('.txt', '')\n",
    "    frame['floor'] = parts[-2]\n",
    "    frame['site'] = parts[-3]\n",
    "    train_frames.append(frame)\n",
    "train_wifi_pd_csv = pd.concat(train_frames).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Floor label -> integer level (basements negative, ground floor = 0); handles both 'F1' and '1F' spellings.\n",
    "floor_map = {\n",
    "    'B3': -3, 'B2': -2, 'B1': -1,\n",
    "    'F1': 0, 'F2': 1, 'F3': 2, 'F4': 3, 'F5': 4, 'F6': 5, 'F7': 6, 'F8': 7, 'F9': 8,\n",
    "    '1F': 0, '2F': 1, '3F': 2, '4F': 3, '5F': 4, '6F': 5, '7F': 6, '8F': 7, '9F': 8,\n",
    "}\n",
    "# Rows whose floor label is not in floor_map are dropped.\n",
    "known_floor = train_wifi_pd_csv['floor'].isin(list(floor_map))\n",
    "train_wifi_pd_csv = train_wifi_pd_csv[known_floor].reset_index(drop=True)\n",
    "train_wifi_pd_csv['floorNo'] = train_wifi_pd_csv['floor'].map(floor_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>last_timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>floorNo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>63159</td>\n",
       "      <td>162932</td>\n",
       "      <td>-46</td>\n",
       "      <td>1578462603277</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>32835</td>\n",
       "      <td>65513</td>\n",
       "      <td>-49</td>\n",
       "      <td>1578462618272</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp   ssid   bssid  rssi  last_timestamp  \\\n",
       "0  1578462618826  63159  162932   -46   1578462603277   \n",
       "1  1578462618826  32835   65513   -49   1578462618272   \n",
       "\n",
       "                       path floor                      site  floorNo  \n",
       "0  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266       -1  \n",
       "1  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266       -1  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_wifi_pd_csv.iloc[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 626/626 [00:02<00:00, 208.96it/s]\n"
     ]
    }
   ],
   "source": [
    "test_frames = []\n",
    "for filename in tqdm(test_wifi_files):\n",
    "    frame = pd.read_csv(filename)\n",
    "    # Test files sit directly in test/ (no site/floor folders), so only the path id comes from the name.\n",
    "    frame['path'] = filename.split('/')[-1].replace('.txt', '')\n",
    "    test_frames.append(frame)\n",
    "test_wifi_pd_csv = pd.concat(test_frames).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>last_timestamp</th>\n",
       "      <th>path</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1961</td>\n",
       "      <td>70537</td>\n",
       "      <td>28318</td>\n",
       "      <td>-34</td>\n",
       "      <td>1571828560156</td>\n",
       "      <td>14f45baa63b4d3a700126af6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1961</td>\n",
       "      <td>43838</td>\n",
       "      <td>93116</td>\n",
       "      <td>-35</td>\n",
       "      <td>1571828560159</td>\n",
       "      <td>14f45baa63b4d3a700126af6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp   ssid  bssid  rssi  last_timestamp                      path\n",
       "0       1961  70537  28318   -34   1571828560156  14f45baa63b4d3a700126af6\n",
       "1       1961  43838  93116   -35   1571828560159  14f45baa63b4d3a700126af6"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_wifi_pd_csv.iloc[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Floor per test path, read from submission_floor.csv (presumably written by a separate floor model).\n",
    "submission = pd.read_csv('submission_floor.csv')\n",
    "submission['path'] = [spt.split('_')[1] for spt in submission['site_path_timestamp']]\n",
    "test_path_floor_dict = dict(zip(submission.path, submission.floor))\n",
    "# Plain dict indexing: a test path missing from the file raises KeyError instead of becoming NaN.\n",
    "test_wifi_pd_csv['floorNo'] = [test_path_floor_dict[test_path] for test_path in test_wifi_pd_csv['path']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardise rssi and floorNo with statistics from the training rows only, then apply the same\n",
    "# transform to test. The scaler gets its own name because later cells reuse `ss` as a scratch variable.\n",
    "rssi_floor_cols = ['rssi', 'floorNo']\n",
    "rssi_floor_scaler = StandardScaler().fit(train_wifi_pd_csv[rssi_floor_cols])\n",
    "# Replace whole columns (df[cols] = ...) instead of .loc[:, cols] = ...: the integer columns must\n",
    "# become float, and .loc setitem tries to write into the existing int columns in place.\n",
    "train_wifi_pd_csv[rssi_floor_cols] = rssi_floor_scaler.transform(train_wifi_pd_csv[rssi_floor_cols])\n",
    "test_wifi_pd_csv[rssi_floor_cols] = rssi_floor_scaler.transform(test_wifi_pd_csv[rssi_floor_cols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>last_timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>floorNo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>63159</td>\n",
       "      <td>162932</td>\n",
       "      <td>3.105926</td>\n",
       "      <td>1578462603277</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1.340327</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1578462618826</td>\n",
       "      <td>32835</td>\n",
       "      <td>65513</td>\n",
       "      <td>2.810727</td>\n",
       "      <td>1578462618272</td>\n",
       "      <td>5e15730aa280850006f3d005</td>\n",
       "      <td>B1</td>\n",
       "      <td>5a0546857ecc773753327266</td>\n",
       "      <td>-1.340327</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp   ssid   bssid      rssi  last_timestamp  \\\n",
       "0  1578462618826  63159  162932  3.105926   1578462603277   \n",
       "1  1578462618826  32835   65513  2.810727   1578462618272   \n",
       "\n",
       "                       path floor                      site   floorNo  \n",
       "0  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266 -1.340327  \n",
       "1  5e15730aa280850006f3d005    B1  5a0546857ecc773753327266 -1.340327  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# rssi and floorNo are now standardised\n",
    "train_wifi_pd_csv.iloc[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10877/10877 [02:51<00:00, 63.43it/s] \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.088208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.040317</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  wifi_median  wifi_std  \n",
       "0     0.206   0.353603     0.350737  1.088208  \n",
       "1     0.220   0.299748     0.350737  1.040317  "
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per (path, timestamp) group: fixed-length ssid/bssid/rssi sequences plus summary stats.\n",
    "ssid_pad, bssid_pad, rssi_pad = ssiddict['empty'], bssiddict['empty'], -10\n",
    "train_wifi_pd = []\n",
    "for path, path_rows in tqdm(train_wifi_pd_csv.groupby('path')):\n",
    "    # Vectorised id encoding; the vocab was built from all train + test files, so a NaN here is a\n",
    "    # lookup miss - fail loudly, as the previous per-row dict indexing did.\n",
    "    encoded = path_rows.assign(ssid=path_rows['ssid'].map(ssiddict),\n",
    "                               bssid=path_rows['bssid'].map(bssiddict))\n",
    "    if encoded[['ssid', 'bssid']].isna().any().any():\n",
    "        raise KeyError(f'unknown ssid/bssid in path {path}')\n",
    "    by_ts = encoded.groupby('timestamp')  # group once, reuse for every column below\n",
    "    # (values + pad * seqlen)[:seqlen] truncates long groups and right-pads short ones to seqlen.\n",
    "    # Local name seq_df avoids clobbering the notebook-level `ss`.\n",
    "    seq_df = pd.concat([\n",
    "        by_ts['ssid'].apply(lambda x: (list(x) + [ssid_pad] * seqlen)[:seqlen]),\n",
    "        by_ts['bssid'].apply(lambda x: (list(x) + [bssid_pad] * seqlen)[:seqlen]),\n",
    "        by_ts['rssi'].apply(lambda x: (list(x) + [rssi_pad] * seqlen)[:seqlen]),\n",
    "    ], axis=1)\n",
    "    seq_df['path'] = path\n",
    "    seq_df['floorNo'] = path_rows['floorNo'].iloc[0]\n",
    "    seq_df['floor'] = path_rows['floor'].iloc[0]\n",
    "    seq_df['site'] = path_rows['site'].iloc[0]\n",
    "    rssi_by_ts = by_ts['rssi']\n",
    "    seq_df['wifi_len'] = rssi_by_ts.count() / 500  # readings per timestamp, divided by a fixed 500\n",
    "    seq_df['wifi_mean'] = rssi_by_ts.mean()\n",
    "    seq_df['wifi_median'] = rssi_by_ts.median()\n",
    "    # NOTE(review): std uses ddof=1, so it is NaN for single-reading timestamps - confirm downstream handles it.\n",
    "    seq_df['wifi_std'] = rssi_by_ts.std()\n",
    "    train_wifi_pd.append(seq_df)\n",
    "train_wifi_pd = pd.concat(train_wifi_pd).reset_index()\n",
    "train_wifi_pd.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 626/626 [00:14<00:00, 41.79it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>site</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1180</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3048</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4924</td>\n",
       "      <td>[9522, 18669, 7007, 19396, 15215, 4851, 15215,...</td>\n",
       "      <td>[10783, 4531, 35106, 19211, 48757, 11767, 3933...</td>\n",
       "      <td>[1.4331324770012857, 1.2363332562632146, 1.039...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.048</td>\n",
       "      <td>-0.149461</td>\n",
       "      <td>-0.436460</td>\n",
       "      <td>0.815521</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6816</td>\n",
       "      <td>[18669, 4851, 15215, 7007, 9522, 19396, 19396,...</td>\n",
       "      <td>[4531, 11767, 39335, 35106, 10783, 19211, 5710...</td>\n",
       "      <td>[1.826730918477428, 1.1379336458941791, 1.0395...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.052</td>\n",
       "      <td>-0.118554</td>\n",
       "      <td>-0.534860</td>\n",
       "      <td>0.911802</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>8693</td>\n",
       "      <td>[18669, 15215, 7007, 4851, 9522, 19396, 15215,...</td>\n",
       "      <td>[4531, 48757, 35106, 11767, 10783, 19211, 3933...</td>\n",
       "      <td>[2.1219297495845346, 1.3347328666322502, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.062</td>\n",
       "      <td>-0.182526</td>\n",
       "      <td>-0.534860</td>\n",
       "      <td>0.905339</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                                               ssid  \\\n",
       "0       1180  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1       3048  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "2       4924  [9522, 18669, 7007, 19396, 15215, 4851, 15215,...   \n",
       "3       6816  [18669, 4851, 15215, 7007, 9522, 19396, 19396,...   \n",
       "4       8693  [18669, 15215, 7007, 4851, 9522, 19396, 15215,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "2  [10783, 4531, 35106, 19211, 48757, 11767, 3933...   \n",
       "3  [4531, 11767, 39335, 35106, 10783, 19211, 5710...   \n",
       "4  [4531, 48757, 35106, 11767, 10783, 19211, 3933...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "2  [1.4331324770012857, 1.2363332562632146, 1.039...   \n",
       "3  [1.826730918477428, 1.1379336458941791, 1.0395...   \n",
       "4  [2.1219297495845346, 1.3347328666322502, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "2  00ff0c9a71cc37a2ebdd0f05  0.845957     0.048  -0.149461    -0.436460   \n",
       "3  00ff0c9a71cc37a2ebdd0f05  0.845957     0.052  -0.118554    -0.534860   \n",
       "4  00ff0c9a71cc37a2ebdd0f05  0.845957     0.062  -0.182526    -0.534860   \n",
       "\n",
       "   wifi_std                      site  \n",
       "0  1.033093  5da1389e4db8ce0c98bd0547  \n",
       "1  0.991529  5da1389e4db8ce0c98bd0547  \n",
       "2  0.815521  5da1389e4db8ce0c98bd0547  \n",
       "3  0.911802  5da1389e4db8ce0c98bd0547  \n",
       "4  0.905339  5da1389e4db8ce0c98bd0547  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per (path, timestamp) group of test readings, built the same way as the training rows.\n",
    "ssid_pad, bssid_pad, rssi_pad = ssiddict['empty'], bssiddict['empty'], -10\n",
    "test_wifi_pd = []\n",
    "for path, path_rows in tqdm(test_wifi_pd_csv.groupby('path')):\n",
    "    # Vectorised id encoding; a NaN means the id is missing from the vocab, so fail loudly.\n",
    "    encoded = path_rows.assign(ssid=path_rows['ssid'].map(ssiddict),\n",
    "                               bssid=path_rows['bssid'].map(bssiddict))\n",
    "    if encoded[['ssid', 'bssid']].isna().any().any():\n",
    "        raise KeyError(f'unknown ssid/bssid in path {path}')\n",
    "    by_ts = encoded.groupby('timestamp')  # group once, reuse for every column below\n",
    "    # (values + pad * seqlen)[:seqlen] truncates long groups and right-pads short ones to seqlen.\n",
    "    # Local name seq_df avoids clobbering the notebook-level `ss`.\n",
    "    seq_df = pd.concat([\n",
    "        by_ts['ssid'].apply(lambda x: (list(x) + [ssid_pad] * seqlen)[:seqlen]),\n",
    "        by_ts['bssid'].apply(lambda x: (list(x) + [bssid_pad] * seqlen)[:seqlen]),\n",
    "        by_ts['rssi'].apply(lambda x: (list(x) + [rssi_pad] * seqlen)[:seqlen]),\n",
    "    ], axis=1)\n",
    "    seq_df['path'] = path\n",
    "    seq_df['floorNo'] = path_rows['floorNo'].iloc[0]\n",
    "    rssi_by_ts = by_ts['rssi']\n",
    "    seq_df['wifi_len'] = rssi_by_ts.count() / 500  # readings per timestamp, divided by a fixed 500\n",
    "    seq_df['wifi_mean'] = rssi_by_ts.mean()\n",
    "    seq_df['wifi_median'] = rssi_by_ts.median()\n",
    "    # NOTE(review): std uses ddof=1, so it is NaN for single-reading timestamps - confirm downstream handles it.\n",
    "    seq_df['wifi_std'] = rssi_by_ts.std()\n",
    "    test_wifi_pd.append(seq_df)\n",
    "test_wifi_pd = pd.concat(test_wifi_pd).reset_index()\n",
    "# Test files carry no site folder, so the site comes from sample_submission via path2site.\n",
    "test_wifi_pd['site'] = [path2site[test_path] for test_path in test_wifi_pd.path]\n",
    "test_wifi_pd.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 204/204 [00:00<00:00, 263.20it/s]\n"
     ]
    }
   ],
   "source": [
    "# filename = train_files[0]\n",
    "train_xy = []\n",
    "for filename in tqdm(train_files):\n",
    "    tmp = pd.read_csv(filename,index_col=0)\n",
    "    ss = tmp[['path','site','floor','ts_waypoint','x','y']]\n",
    "    train_xy.append(ss)\n",
    "train_xy = pd.concat(train_xy).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(166681, 6)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_xy=train_xy.drop_duplicates()\n",
    "train_xy.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path = '5e15730aa280850006f3d005'\n",
    "# train_wifi_pd_x = train_wifi_pd[train_wifi_pd.path==path]\n",
    "# train_y = train_xy[train_xy.path==path][['path','ts_waypoint','x','y']].drop_duplicates().reset_index(drop=True)\n",
    "# if len(train_y)==0:\n",
    "#     print(path,'have no waypoint')\n",
    "# if len(train_y)>0:\n",
    "#     ts_point_min = train_y.ts_waypoint.min()\n",
    "#     ts_point_max = train_y.ts_waypoint.max()\n",
    "#     tmp2 = train_wifi_pd_x[['timestamp']].drop_duplicates()\n",
    "#     tmp2 = tmp2[(tmp2.timestamp<=ts_point_max)&(tmp2.timestamp>=ts_point_min)]\n",
    "#     if len(tmp2)>0:\n",
    "#         T_rel = train_y['ts_waypoint']\n",
    "#         T_ref = tmp2['timestamp']\n",
    "#         xy_hat = scipy.interpolate.interp1d(T_rel, train_y[['x','y']], axis=0)(T_ref)\n",
    "#         tmp2['x'] = xy_hat[:,0]\n",
    "#         tmp2['y'] = xy_hat[:,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10877/10877 [03:16<00:00, 55.30it/s]\n"
     ]
    }
   ],
   "source": [
    "import scipy.stats as stats\n",
    "import scipy\n",
    "train_all = []\n",
    "\n",
    "for path,train_wifi_pd_x in tqdm(train_wifi_pd.groupby('path')):\n",
    "    # path = '5e15730aa280850006f3d005'\n",
    "    train_y = train_xy[train_xy.path==path][['path','ts_waypoint','x','y']].drop_duplicates().reset_index(drop=True)\n",
    "    train_wifi_pd_x['ts_waypoint'] = 0\n",
    "    if len(train_y)==0:\n",
    "        print(path,'have no waypoint')\n",
    "    if len(train_y)>0:\n",
    "        ts_point_min = train_y.ts_waypoint.min()\n",
    "        ts_point_max = train_y.ts_waypoint.max()\n",
    "        tmp2 = train_wifi_pd_x[['timestamp']].drop_duplicates()\n",
    "        tmp2 = tmp2[(tmp2.timestamp<=ts_point_max)&(tmp2.timestamp>=ts_point_min)]\n",
    "        if len(tmp2)>0:\n",
    "            T_rel = train_y['ts_waypoint']\n",
    "            T_ref = tmp2['timestamp']\n",
    "            xy_hat = scipy.interpolate.interp1d(T_rel, train_y[['x','y']], axis=0)(T_ref)\n",
    "            tmp2['x'] = xy_hat[:,0]\n",
    "            tmp2['y'] = xy_hat[:,1]\n",
    "            tmp2['path'] = path\n",
    "            train_wifi_pd_x = pd.merge(train_wifi_pd_x,tmp2,how='left',on=['path','timestamp'])\n",
    "            train_all.append(train_wifi_pd_x)\n",
    "    \n",
    "train_all = pd.concat(train_all).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(258097, 15)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ###use nearest location\n",
    "# train_all = []\n",
    "\n",
    "# for path,train_wifi_pd_x in tqdm(train_wifi_pd.groupby('path')):\n",
    "#     # path = '5e15730aa280850006f3d005'\n",
    "#     train_y = train_xy[train_xy.path==path][['path','ts_waypoint','x','y']].drop_duplicates().reset_index(drop=True)\n",
    "#     train_wifi_pd_x['ts_waypoint'] = 0\n",
    "#     if len(train_y)==0:\n",
    "#         print(path,'have no waypoint')\n",
    "#     if len(train_y)>0:\n",
    "#         timestamplist = np.array(train_y.ts_waypoint)\n",
    "#         for ii in train_wifi_pd_x.index:\n",
    "#             distlist = np.abs(timestamplist-train_wifi_pd_x.loc[ii,'timestamp'])\n",
    "#             nearest_wp_index = np.argmin(distlist)\n",
    "#             train_wifi_pd_x.loc[ii,'ts_waypoint'] = int(timestamplist[nearest_wp_index])\n",
    "#         train_wifi_pd_x = pd.merge(train_wifi_pd_x,train_y,how='left',on=['path','ts_waypoint'])\n",
    "#         train_all.append(train_wifi_pd_x)\n",
    "    \n",
    "# train_all = pd.concat(train_all).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((11756, 15), (11756, 15))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all[train_all.x.isna()].shape,train_all[train_all.y.isna()].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_all = train_all[~train_all.x.isna()].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>ts_waypoint</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.088208</td>\n",
       "      <td>0</td>\n",
       "      <td>195.790623</td>\n",
       "      <td>93.465301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.040317</td>\n",
       "      <td>0</td>\n",
       "      <td>193.591333</td>\n",
       "      <td>92.973266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1560501001590</td>\n",
       "      <td>[18304, 19396, 7702, 7702, 19396, 7702, 12721,...</td>\n",
       "      <td>[10121, 57287, 31140, 61027, 55262, 22353, 603...</td>\n",
       "      <td>[3.1059258532748903, 3.1059258532748903, 2.810...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.238</td>\n",
       "      <td>0.268875</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.046341</td>\n",
       "      <td>0</td>\n",
       "      <td>191.394344</td>\n",
       "      <td>92.481745</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1560501003516</td>\n",
       "      <td>[19396, 7702, 19396, 18304, 7702, 7702, 7702, ...</td>\n",
       "      <td>[57287, 31140, 55262, 10121, 22353, 53865, 432...</td>\n",
       "      <td>[3.1059258532748903, 2.8107270221677836, 2.613...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.258</td>\n",
       "      <td>0.230216</td>\n",
       "      <td>0.252337</td>\n",
       "      <td>0.995631</td>\n",
       "      <td>0</td>\n",
       "      <td>189.177791</td>\n",
       "      <td>91.985848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1560501005442</td>\n",
       "      <td>[7702, 18304, 19396, 19396, 7702, 7702, 7702, ...</td>\n",
       "      <td>[31140, 10121, 55262, 57287, 43265, 61027, 612...</td>\n",
       "      <td>[2.8107270221677836, 2.6139278014297127, 2.613...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.282</td>\n",
       "      <td>0.210465</td>\n",
       "      <td>0.252337</td>\n",
       "      <td>0.963630</td>\n",
       "      <td>0</td>\n",
       "      <td>186.961238</td>\n",
       "      <td>91.489950</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "2  1560501001590  [18304, 19396, 7702, 7702, 19396, 7702, 12721,...   \n",
       "3  1560501003516  [19396, 7702, 19396, 18304, 7702, 7702, 7702, ...   \n",
       "4  1560501005442  [7702, 18304, 19396, 19396, 7702, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "2  [10121, 57287, 31140, 61027, 55262, 22353, 603...   \n",
       "3  [57287, 31140, 55262, 10121, 22353, 53865, 432...   \n",
       "4  [31140, 10121, 55262, 57287, 43265, 61027, 612...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "2  [3.1059258532748903, 3.1059258532748903, 2.810...   \n",
       "3  [3.1059258532748903, 2.8107270221677836, 2.613...   \n",
       "4  [2.8107270221677836, 2.6139278014297127, 2.613...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "2  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "3  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "4  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  wifi_median  wifi_std  ts_waypoint           x  \\\n",
       "0     0.206   0.353603     0.350737  1.088208            0  195.790623   \n",
       "1     0.220   0.299748     0.350737  1.040317            0  193.591333   \n",
       "2     0.238   0.268875     0.350737  1.046341            0  191.394344   \n",
       "3     0.258   0.230216     0.252337  0.995631            0  189.177791   \n",
       "4     0.282   0.210465     0.252337  0.963630            0  186.961238   \n",
       "\n",
       "           y  \n",
       "0  93.465301  \n",
       "1  92.973266  \n",
       "2  92.481745  \n",
       "3  91.985848  \n",
       "4  91.489950  "
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(246341, 15)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from sklearn.model_selection import StratifiedKFold\n",
    "# from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
    "# N_SPLITS = 10\n",
    "# SEED = 42\n",
    "# for fold, (trn_idx, val_idx) in enumerate(StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED).split(train_all['site'], train_all['site'])):\n",
    "#     train_all.loc[val_idx, 'fold'] = fold\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import KFold\n",
    "N_SPLITS = 10\n",
    "\n",
    "path_list = train_all['path'].unique()\n",
    "folds = KFold(n_splits=N_SPLITS, shuffle=True, random_state=1024) \n",
    "for n_fold, (train_idx, valid_idx) in enumerate(folds.split(path_list), start=0):\n",
    "    train_all.loc[train_all['path'].isin(path_list[valid_idx]), 'fold'] = n_fold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all[train_all.path=='5dd3824044333f00067aa2c4'].fold.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all[train_all.site=='5c3c44b80379370013e0fd2b'].fold.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>ts_waypoint</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>fold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.088208</td>\n",
       "      <td>0</td>\n",
       "      <td>195.790623</td>\n",
       "      <td>93.465301</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.040317</td>\n",
       "      <td>0</td>\n",
       "      <td>193.591333</td>\n",
       "      <td>92.973266</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  wifi_median  wifi_std  ts_waypoint           x  \\\n",
       "0     0.206   0.353603     0.350737  1.088208            0  195.790623   \n",
       "1     0.220   0.299748     0.350737  1.040317            0  193.591333   \n",
       "\n",
       "           y  fold  \n",
       "0  93.465301   6.0  \n",
       "1  92.973266   6.0  "
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all['length'] = [len(xx) for xx in train_all['bssid']]\n",
    "# del train_all['length']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tmp1 = train_all[['x','y']].values\n",
    "# tmp1 = pd.DataFrame(list(zip(tmp1)),columns = ['xy'])\n",
    "# train_all = pd.concat([train_all,tmp1],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all_timestamp_min = train_all.timestamp.min()\n",
    "# train_all_timestamp_max = train_all.timestamp.max()\n",
    "# train_all['timestamp'] = (train_all['timestamp']-train_all_timestamp_min)/(train_all_timestamp_max-train_all_timestamp_min)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# floor_map = {\"B2\":-2, \"B1\":-1, \"F1\":0, \"F2\": 1, \"F3\":2, \"F4\":3, \"F5\":4, \"F6\":5, \"F7\":6,\"F8\":7, \"F9\":8,\n",
    "#              \"1F\":0, \"2F\":1, \"3F\":2, \"4F\":3, \"5F\":4, \"6F\":5, \"7F\":6, \"8F\": 7, \"9F\":8}\n",
    "# train_all['floor'] = train_all['floor'].apply(lambda x: floor_map[x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>floor</th>\n",
       "      <th>site</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>ts_waypoint</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>fold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1560500997770</td>\n",
       "      <td>[7702, 19396, 18304, 19396, 7702, 7702, 19396,...</td>\n",
       "      <td>[61027, 55262, 10121, 57287, 45809, 53865, 261...</td>\n",
       "      <td>[3.204325463643926, 3.1059258532748903, 2.9091...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.206</td>\n",
       "      <td>0.353603</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.088208</td>\n",
       "      <td>0</td>\n",
       "      <td>195.790623</td>\n",
       "      <td>93.465301</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1560500999681</td>\n",
       "      <td>[18304, 7702, 7702, 19396, 19396, 7702, 7702, ...</td>\n",
       "      <td>[10121, 31140, 61027, 55262, 57287, 53865, 458...</td>\n",
       "      <td>[2.712327411798748, 2.712327411798748, 2.61392...</td>\n",
       "      <td>5d073b814a19c000086c558b</td>\n",
       "      <td>0.299386</td>\n",
       "      <td>F3</td>\n",
       "      <td>5c3c44b80379370013e0fd2b</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.299748</td>\n",
       "      <td>0.350737</td>\n",
       "      <td>1.040317</td>\n",
       "      <td>0</td>\n",
       "      <td>193.591333</td>\n",
       "      <td>92.973266</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1560500997770  [7702, 19396, 18304, 19396, 7702, 7702, 19396,...   \n",
       "1  1560500999681  [18304, 7702, 7702, 19396, 19396, 7702, 7702, ...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [61027, 55262, 10121, 57287, 45809, 53865, 261...   \n",
       "1  [10121, 31140, 61027, 55262, 57287, 53865, 458...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [3.204325463643926, 3.1059258532748903, 2.9091...   \n",
       "1  [2.712327411798748, 2.712327411798748, 2.61392...   \n",
       "\n",
       "                       path   floorNo floor                      site  \\\n",
       "0  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "1  5d073b814a19c000086c558b  0.299386    F3  5c3c44b80379370013e0fd2b   \n",
       "\n",
       "   wifi_len  wifi_mean  wifi_median  wifi_std  ts_waypoint           x  \\\n",
       "0     0.206   0.353603     0.350737  1.088208            0  195.790623   \n",
       "1     0.220   0.299748     0.350737  1.040317            0  193.591333   \n",
       "\n",
       "           y  fold  \n",
       "0  93.465301   6.0  \n",
       "1  92.973266   6.0  "
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>site</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1180</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3048</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                                               ssid  \\\n",
       "0       1180  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1       3048  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "\n",
       "   wifi_std                      site  \n",
       "0  1.033093  5da1389e4db8ce0c98bd0547  \n",
       "1  0.991529  5da1389e4db8ce0c98bd0547  "
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_wifi_pd.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(37678, 11)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_all = test_wifi_pd.copy()\n",
    "test_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from dask.distributed import wait\n",
    "\n",
    "# Motion sensors that share the [timestamp, x, y, z] record layout\n",
    "SENSORS = ['acce', 'acce_uncali', 'gyro',\n",
    "           'gyro_uncali', 'magn', 'magn_uncali', 'ahrs']\n",
    "\n",
    "# Per-stream feature counts (3 for every motion sensor); not used in this cell\n",
    "NFEAS = {sensor: 3 for sensor in SENSORS}\n",
    "NFEAS.update({'wifi': 1, 'ibeacon': 1, 'waypoint': 3})\n",
    "\n",
    "ACOLS = ['timestamp', 'x', 'y', 'z']\n",
    "\n",
    "# Column names used by to_frame when turning each parsed stream into a DataFrame\n",
    "FIELDS = {sensor: ACOLS for sensor in SENSORS}\n",
    "FIELDS.update({\n",
    "    'wifi': ['timestamp', 'ssid', 'bssid', 'rssi', 'last_timestamp'],\n",
    "    'ibeacon': ['timestamp', 'code', 'rssi', 'last_timestamp'],\n",
    "    'waypoint': ['timestamp', 'x', 'y'],\n",
    "})\n",
    "\n",
    "def to_frame(data, col):\n",
    "    \"\"\"Convert one parsed data stream into a DataFrame.\n",
    "\n",
    "    `col` names the stream (a key of FIELDS) and selects the column names.\n",
    "    An empty `data` array is replaced by the one-row placeholder from\n",
    "    create_dummy_df. Every column whose name contains 'timestamp' is cast to int64.\n",
    "\n",
    "    Returns (frame, is_dummy), where is_dummy is True for the placeholder.\n",
    "    \"\"\"\n",
    "    columns = FIELDS[col]\n",
    "    is_dummy = data.shape[0] == 0\n",
    "    frame = create_dummy_df(columns) if is_dummy else pd.DataFrame(data, columns=columns)\n",
    "    ts_columns = [name for name in frame.columns if 'timestamp' in name]\n",
    "    for name in ts_columns:\n",
    "        frame[name] = frame[name].astype('int64')\n",
    "    return frame, is_dummy\n",
    "\n",
    "def create_dummy_df(cols):\n",
    "    \"\"\"Build a one-row placeholder frame with the given columns.\n",
    "\n",
    "    Every cell is the integer 0, except 'ssid' and 'bssid', which hold the string '0'.\n",
    "    \"\"\"\n",
    "    row = {name: (['0'] if name in ['ssid', 'bssid'] else [0]) for name in cols}\n",
    "    return pd.DataFrame(row)\n",
    "\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ReadData:\n",
    "    \"\"\"All streams parsed from one trace file by read_data_file.\n",
    "\n",
    "    Each field is a NumPy array built from that stream's list of records\n",
    "    (an empty array when the record type never appears in the file).\n",
    "    \"\"\"\n",
    "    acce: np.ndarray  # TYPE_ACCELEROMETER rows: [timestamp, x, y, z]\n",
    "    acce_uncali: np.ndarray  # TYPE_ACCELEROMETER_UNCALIBRATED rows: [timestamp, x, y, z]\n",
    "    gyro: np.ndarray  # TYPE_GYROSCOPE rows: [timestamp, x, y, z]\n",
    "    gyro_uncali: np.ndarray  # TYPE_GYROSCOPE_UNCALIBRATED rows: [timestamp, x, y, z]\n",
    "    magn: np.ndarray  # TYPE_MAGNETIC_FIELD rows: [timestamp, x, y, z]\n",
    "    magn_uncali: np.ndarray  # TYPE_MAGNETIC_FIELD_UNCALIBRATED rows: [timestamp, x, y, z]\n",
    "    ahrs: np.ndarray  # TYPE_ROTATION_VECTOR rows: [timestamp, v1, v2, v3]\n",
    "    wifi: np.ndarray  # TYPE_WIFI rows as strings: [sys_ts, ssid, bssid, rssi, lastseen_ts]\n",
    "    ibeacon: np.ndarray  # TYPE_BEACON rows as strings: [ts, uuid_major_minor, rssi, last_ts]\n",
    "    waypoint: np.ndarray  # TYPE_WAYPOINT rows: [timestamp, x, y]\n",
    "\n",
    "\n",
    "def read_data_file(data_filename):\n",
    "    \"\"\"Parse one tab-separated trace file into a ReadData bundle.\n",
    "\n",
    "    Blank lines and lines starting with '#' are skipped, and unknown record\n",
    "    types are ignored. Sensor records become [timestamp, x, y, z]; rotation-vector\n",
    "    records with fewer than five fields are dropped. Waypoints become\n",
    "    [timestamp, x, y], and wifi/beacon records keep their raw string fields.\n",
    "    Each stream is returned as a NumPy array (empty when the type never occurs).\n",
    "    \"\"\"\n",
    "    # Record type -> ReadData field, for the sensors stored as [ts, x, y, z]\n",
    "    xyz_streams = {\n",
    "        'TYPE_ACCELEROMETER': 'acce',\n",
    "        'TYPE_ACCELEROMETER_UNCALIBRATED': 'acce_uncali',\n",
    "        'TYPE_GYROSCOPE': 'gyro',\n",
    "        'TYPE_GYROSCOPE_UNCALIBRATED': 'gyro_uncali',\n",
    "        'TYPE_MAGNETIC_FIELD': 'magn',\n",
    "        'TYPE_MAGNETIC_FIELD_UNCALIBRATED': 'magn_uncali',\n",
    "    }\n",
    "    streams = {key: [] for key in ('acce', 'acce_uncali', 'gyro', 'gyro_uncali', 'magn',\n",
    "                                   'magn_uncali', 'ahrs', 'wifi', 'ibeacon', 'waypoint')}\n",
    "\n",
    "    with open(data_filename, 'r', encoding='utf-8') as file:\n",
    "        raw_lines = file.readlines()\n",
    "\n",
    "    for raw in raw_lines:\n",
    "        raw = raw.strip()\n",
    "        if not raw or raw[0] == '#':\n",
    "            continue\n",
    "        fields = raw.split('\\t')\n",
    "        record_type = fields[1]\n",
    "\n",
    "        if record_type in xyz_streams:\n",
    "            streams[xyz_streams[record_type]].append(\n",
    "                [int(fields[0]), float(fields[2]), float(fields[3]), float(fields[4])])\n",
    "        elif record_type == 'TYPE_ROTATION_VECTOR':\n",
    "            if len(fields) >= 5:\n",
    "                streams['ahrs'].append(\n",
    "                    [int(fields[0]), float(fields[2]), float(fields[3]), float(fields[4])])\n",
    "        elif record_type == 'TYPE_WIFI':\n",
    "            # [sys_ts, ssid, bssid, rssi, lastseen_ts]; field 5 is not kept\n",
    "            streams['wifi'].append([fields[0], fields[2], fields[3], fields[4], fields[6]])\n",
    "        elif record_type == 'TYPE_BEACON':\n",
    "            # [ts, 'uuid_major_minor', rssi, last_ts]\n",
    "            beacon_code = '_'.join([fields[2], fields[3], fields[4]])\n",
    "            streams['ibeacon'].append([fields[0], beacon_code, fields[6], fields[-1]])\n",
    "        elif record_type == 'TYPE_WAYPOINT':\n",
    "            streams['waypoint'].append([int(fields[0]), float(fields[2]), float(fields[3])])\n",
    "\n",
    "    return ReadData(**{key: np.array(rows) for key, rows in streams.items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_test_dfs(PATH, test_files):\n",
    "    \"\"\"Split the sample submission into one frame of rows per test trace file.\n",
    "\n",
    "    PATH: directory containing sample_submission.csv (read via get_test_df).\n",
    "    test_files: test trace file paths; each base name up to the first '.' is\n",
    "        matched against the submission's `path` column.\n",
    "    Returns a dict {file name: DataFrame} with columns timestamp, x, y, floor,\n",
    "    building, site_path_timestamp and a fresh 0..n-1 index.\n",
    "    \"\"\"\n",
    "    dtest = get_test_df(PATH)\n",
    "    keep_cols = ['timestamp','x','y','floor','building','site_path_timestamp']\n",
    "    dws = {}\n",
    "    for fname in tqdm(test_files):\n",
    "        path = fname.split('/')[-1].split('.')[0]\n",
    "        mask = dtest['path'] == path\n",
    "        dws[fname] = dtest.loc[mask, keep_cols].copy().reset_index(drop=True)\n",
    "    return dws\n",
    "\n",
    "def get_test_df(PATH):\n",
    "    \"\"\"Load sample_submission.csv and split its site_path_timestamp key.\n",
    "\n",
    "    Adds `building` and `path` (strings) and `timestamp` (int64), taken from the\n",
    "    '_'-separated key. Returns the frame sorted by (path, timestamp) with a\n",
    "    fresh 0..n-1 index.\n",
    "    \"\"\"\n",
    "    sub = pd.read_csv(f'{PATH}/sample_submission.csv')\n",
    "    key_parts = sub['site_path_timestamp'].str.split('_', expand=True)\n",
    "    sub['building'] = key_parts[0]\n",
    "    sub['path'] = key_parts[1]\n",
    "    sub['timestamp'] = key_parts[2].astype('int64')\n",
    "    return sub.sort_values(['path', 'timestamp']).reset_index(drop=True)\n",
    "\n",
    "def get_time_gap(name):\n",
    "    \"\"\"Estimate the constant offset between a trace's `timestamp` and `last_timestamp` clocks.\n",
    "\n",
    "    With beacon records, the offset is `last_timestamp - timestamp`, which must be\n",
    "    the same for every beacon record. Without beacons, it is estimated from the\n",
    "    wifi records as the largest `lastseen_ts - sys_ts` difference.\n",
    "\n",
    "    Returns (gap, no_ibeacon); no_ibeacon is True when the wifi fallback was used.\n",
    "    Raises ValueError if the beacon gaps disagree, or if the trace has neither\n",
    "    beacon nor wifi records.\n",
    "    \"\"\"\n",
    "    data = read_data_file(name)\n",
    "    db, no_ibeacon = to_frame(data.ibeacon, 'ibeacon')\n",
    "\n",
    "    if not no_ibeacon:\n",
    "        gap = db['last_timestamp'] - db['timestamp']\n",
    "        # Explicit check instead of assert, which is stripped under python -O.\n",
    "        if gap.nunique() != 1:\n",
    "            raise ValueError(f'{name}: beacon timestamp gaps are not constant')\n",
    "        return gap.values[0], no_ibeacon\n",
    "\n",
    "    if data.wifi.shape[0] == 0:\n",
    "        raise ValueError(f'{name}: no beacon or wifi records to estimate the time gap')\n",
    "    # Columns 0 and 4 hold sys_ts and lastseen_ts as raw strings. Cast to int64 before\n",
    "    # comparing: string max is lexicographic, and plain int is 32-bit on some platforms.\n",
    "    wifi_ts = pd.DataFrame(data.wifi)[[0, 4]].astype('int64')\n",
    "    # sys_ts is constant within a scan, so the max over scans of (max lastseen - sys_ts)\n",
    "    # equals the max per-record difference.\n",
    "    est_ts = (wifi_ts[4] - wifi_ts[0]).max()\n",
    "    return est_ts, no_ibeacon\n",
    "\n",
    "    \n",
    "\n",
    "def fix_timestamp_test(df, gap):\n",
    "    \"\"\"Store `timestamp + gap` in a `real_timestamp` column.\n",
    "\n",
    "    Mutates `df` in place and also returns it.\n",
    "    \"\"\"\n",
    "    shifted = df['timestamp'] + gap\n",
    "    df['real_timestamp'] = shifted\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../input/indoor-location-navigation/test/00ff0c9a71cc37a2ebdd0f05.txt',\n",
       " '../input/indoor-location-navigation/test/01c41f1aeba5c48c2c4dd568.txt',\n",
       " '../input/indoor-location-navigation/test/030b3d94de8acae7c936563d.txt',\n",
       " '../input/indoor-location-navigation/test/0389421238a7e2839701df0f.txt']"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# glob returns files in arbitrary, filesystem-dependent order; sort so the\n",
    "# file order (and everything iterating over it) is reproducible.\n",
    "test_files_ori = sorted(glob.glob('../input/indoor-location-navigation/test/*.txt'))\n",
    "test_files_ori[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table style=\"border: 2px solid white;\">\n",
       "<tr>\n",
       "<td style=\"vertical-align: top; border: 0px solid white\">\n",
       "<h3 style=\"text-align: left;\">Client</h3>\n",
       "<ul style=\"text-align: left; list-style: none; margin: 0; padding: 0;\">\n",
       "  <li><b>Scheduler: </b>tcp://127.0.0.1:36641</li>\n",
       "  <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n",
       "</ul>\n",
       "</td>\n",
       "<td style=\"vertical-align: top; border: 0px solid white\">\n",
       "<h3 style=\"text-align: left;\">Cluster</h3>\n",
       "<ul style=\"text-align: left; list-style:none; margin: 0; padding: 0;\">\n",
       "  <li><b>Workers: </b>8</li>\n",
       "  <li><b>Cores: </b>8</li>\n",
       "  <li><b>Memory: </b>66.71 GB</li>\n",
       "</ul>\n",
       "</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<Client: 'tcp://127.0.0.1:36641' processes=8 threads=8, memory=66.71 GB>"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import dask\n",
    "from dask.distributed import Client, wait, LocalCluster\n",
    "\n",
    "# One single-threaded worker per CPU core; os.cpu_count() may return None.\n",
    "N_WORKERS = os.cpu_count() or 1\n",
    "client = Client(n_workers=N_WORKERS,\n",
    "                threads_per_worker=1)\n",
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 626/626 [00:00<00:00, 10654.38it/s]\n",
      "100%|██████████| 626/626 [00:17<00:00, 34.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.87 s, sys: 169 ms, total: 3.04 s\n",
      "Wall time: 18 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Fan out one get_time_gap task per test trace to the dask workers.\n",
    "futures = [client.submit(get_time_gap, fname)\n",
    "           for fname in tqdm(test_files_ori, total=len(test_files_ori))]\n",
    "\n",
    "# Gather (gap, no_ibeacon) keyed by path id (the file name without '.txt').\n",
    "testpath2gap = {}\n",
    "for fname, f in tqdm(zip(test_files_ori, futures), total=len(test_files_ori)):\n",
    "    path_id = fname.split('/')[-1].replace('.txt', '')\n",
    "    testpath2gap[path_id] = f.result()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shift each test timestamp by its path's estimated gap (element 0 of testpath2gap[path]).\n",
    "path_gap = test_all['path'].map(lambda p: testpath2gap[p][0])\n",
    "test_all['timestamp'] = test_all['timestamp'] + path_gap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_all['timestamp'] = (test_all['timestamp']-train_all_timestamp_min)/(train_all_timestamp_max-train_all_timestamp_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>site</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1573190312033</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1573190313901</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       timestamp                                               ssid  \\\n",
       "0  1573190312033  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1  1573190313901  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "\n",
       "   wifi_std                      site  \n",
       "0  1.033093  5da1389e4db8ce0c98bd0547  \n",
       "1  0.991529  5da1389e4db8ce0c98bd0547  "
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview test_all after shifting timestamps by the per-path gap\n",
    "test_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Standardise timestamps with statistics fitted on the training rows only.\n",
    "ss2 = StandardScaler()\n",
    "train_all.loc[:, ['timestamp']] = ss2.fit_transform(train_all.loc[:, ['timestamp']])\n",
    "test_all.loc[:, ['timestamp']] = ss2.transform(test_all.loc[:, ['timestamp']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_all_floor_min = train_all.floor.min()\n",
    "# train_all_floor_max = train_all.floor.max()\n",
    "# train_all['floor'] = (train_all['floor']-train_all_floor_min)/(train_all_floor_max-train_all_floor_min)\n",
    "# test_all['floor'] = (test_all['floor']-train_all_floor_min)/(train_all_floor_max-train_all_floor_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer id per site, assigned in sorted order of the training sites.\n",
    "# A test site absent from train raises KeyError here instead of getting a NaN id.\n",
    "sitelist = sorted(set(train_all.site))\n",
    "sitedict = {site: idx for idx, site in enumerate(sitelist)}\n",
    "train_all['site_id'] = train_all['site'].map(sitedict.__getitem__)\n",
    "test_all['site_id'] = test_all['site'].map(sitedict.__getitem__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MCRMSE(y_true, y_pred):\n",
    "    \"\"\"Mean column-wise RMSE.\n",
    "\n",
    "    Squared errors are averaged over the second-to-last axis, square-rooted,\n",
    "    then averaged over the last (column) axis. For (batch, seq, cols) inputs\n",
    "    this returns one value per sample. For (batch, cols) inputs, such as the\n",
    "    (n, 2) x/y labels in this notebook, it returns a single scalar.\n",
    "    \"\"\"\n",
    "    colwise_mse = tf.reduce_mean(tf.square(y_true - y_pred), axis=-2)\n",
    "    return tf.reduce_mean(tf.sqrt(colwise_mse), axis=-1)\n",
    "\n",
    "def gru_layer(hidden_dim, dropout):\n",
    "    \"\"\"Bidirectional GRU (orthogonal kernel init) that returns the full output sequence.\"\"\"\n",
    "    gru = L.GRU(hidden_dim,\n",
    "                dropout=dropout,\n",
    "                return_sequences=True,\n",
    "                kernel_initializer='orthogonal')\n",
    "    return L.Bidirectional(gru)\n",
    "\n",
    "def pandas_list_to_array(df):\n",
    "    \"\"\"Stack list-valued cells into a 3-D array.\n",
    "\n",
    "    Input: DataFrame of shape (x, y) whose cells are lists of equal length l.\n",
    "    Returns: np.ndarray of shape (x, l, y).\n",
    "    \"\"\"\n",
    "    stacked = np.array(df.values.tolist())  # shape (x, y, l)\n",
    "    return stacked.transpose(0, 2, 1)\n",
    "\n",
    "def preprocess_inputs(df, cols=('ssid', 'bssid', 'rssi')):\n",
    "    \"\"\"Turn the list-valued `cols` of `df` into a (rows, l, len(cols)) array.\n",
    "\n",
    "    `cols` defaults to an immutable tuple rather than a shared mutable list.\n",
    "    Any sequence of column names is accepted; it is converted to a list so\n",
    "    pandas selects several columns instead of looking up a single tuple key.\n",
    "    \"\"\"\n",
    "    return pandas_list_to_array(df[list(cols)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model_time(embed_size, seq_len=100, pred_len=2, dropout=0.5, \n",
    "                sp_dropout=0.2, embed_dim=200, hidden_dim=256, n_layers=2):\n",
    "    \"\"\"Bidirectional-GRU regressor over a (seq_len, 2) sequence plus one scalar input.\n",
    "\n",
    "    Channel 0 of `inputs` is a categorical id, embedded with vocabulary size\n",
    "    `embed_size` and width `embed_dim`. Channel 1 is passed through as a numeric\n",
    "    feature. After `n_layers` bidirectional GRU layers (see gru_layer), the first\n",
    "    `pred_len` time steps are flattened, concatenated with `input_time` (shape (1,)),\n",
    "    and mapped to 2 linear outputs. Compiled with Adam (default settings) and MSE loss.\n",
    "\n",
    "    Returns the compiled tf.keras.Model with inputs [inputs, input_time].\n",
    "    \"\"\"\n",
    "    inputs = L.Input(shape=(seq_len, 2))\n",
    "    input_time = L.Input(shape = (1,))\n",
    "    \n",
    "\n",
    "    categorical_fea = inputs[:, :, :1]\n",
    "    numerical_fea = inputs[:, :, 1:]\n",
    "\n",
    "    embed = L.Embedding(input_dim=embed_size, output_dim=embed_dim)(categorical_fea)\n",
    "    # (batch, seq_len, 1, embed_dim) -> (batch, seq_len, embed_dim): merge the last two axes\n",
    "    reshaped = tf.reshape(embed, shape=(-1, embed.shape[1],  embed.shape[2] * embed.shape[3]))\n",
    "    reshaped = L.SpatialDropout1D(sp_dropout)(reshaped)\n",
    "    \n",
    "    \n",
    "    hidden = L.concatenate([reshaped, numerical_fea], axis=2)\n",
    "    \n",
    "    for x in range(n_layers):\n",
    "        hidden = gru_layer(hidden_dim, dropout)(hidden)\n",
    "    \n",
    "    # Keep only the first `pred_len` time steps of the GRU output and flatten them\n",
    "    # into one feature vector per sample\n",
    "    truncated = hidden[:, :pred_len]\n",
    "    truncated = L.Flatten()(truncated)\n",
    "    truncated = L.concatenate([truncated, input_time], axis=1)\n",
    "\n",
    "    out = L.Dense(2, activation='linear')(truncated)\n",
    "\n",
    "        \n",
    "    model = tf.keras.Model(inputs=[inputs,input_time], outputs=out)\n",
    "    model.compile(tf.optimizers.Adam(), loss='mse')\n",
    "    \n",
    "    return model\n",
    "\n",
    "def get_embed_size(n_cat):\n",
    "    \"\"\"Heuristic embedding width for `n_cat` categories, capped at 600.\"\"\"\n",
    "    suggested = round(1.6 * n_cat ** 0.56)\n",
    "    return min(600, suggested)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model_mix(sid_size,bssid_size,site_size, seq_len=100, pred_len=2, dropout=0.2, \n",
    "                sp_dropout=0.1, embed_dim=64, hidden_dim=128, n_layers=3,lr=0.001):\n",
    "    \"\"\"Bidirectional-GRU regressor over (seq_len, 3) sequences plus auxiliary and site inputs.\n",
    "\n",
    "    Channels 0 and 1 of `inputs` are categorical ids, embedded with vocabularies\n",
    "    `sid_size` and `bssid_size` (width `embed_dim` each). Channel 2 is a numeric\n",
    "    feature. In the training cell these are the ssid, bssid and rssi columns;\n",
    "    `input_time` (shape (4,)) holds timestamp, floorNo, wifi_len and wifi_mean;\n",
    "    and `input_site` is the integer site_id, embedded to one value (vocabulary\n",
    "    `site_size`). After `n_layers` bidirectional GRU layers, the first `pred_len`\n",
    "    time steps are flattened, concatenated with `input_time` and the site\n",
    "    embedding, and mapped to 2 linear outputs. Compiled with Adam(lr) and MSE loss.\n",
    "\n",
    "    Returns the compiled tf.keras.Model with inputs [inputs, input_time, input_site].\n",
    "    \"\"\"\n",
    "    inputs = L.Input(shape=(seq_len, 3))\n",
    "    input_time = L.Input(shape = (4,))\n",
    "    input_site = L.Input(shape = (1,))\n",
    "        \n",
    "    categorical_fea1 = inputs[:, :, :1]\n",
    "    categorical_fea2 = inputs[:, :, 1:2]\n",
    "    numerical_fea = inputs[:, :, 2:]\n",
    "    \n",
    "\n",
    "    embed = L.Embedding(input_dim=sid_size, output_dim=embed_dim)(categorical_fea1)\n",
    "    # (batch, seq_len, 1, embed_dim) -> (batch, seq_len, embed_dim)\n",
    "    reshaped = tf.reshape(embed, shape=(-1, embed.shape[1],  embed.shape[2] * embed.shape[3]))\n",
    "    reshaped = L.SpatialDropout1D(sp_dropout)(reshaped)\n",
    "    \n",
    "    embed2 = L.Embedding(input_dim=bssid_size, output_dim=embed_dim)(categorical_fea2)\n",
    "    # same axis merge for the second categorical channel\n",
    "    reshaped2 = tf.reshape(embed2, shape=(-1, embed2.shape[1],  embed2.shape[2] * embed2.shape[3]))\n",
    "    reshaped2 = L.SpatialDropout1D(sp_dropout)(reshaped2)\n",
    "    \n",
    "    \n",
    "    hidden = L.concatenate([reshaped, reshaped2, numerical_fea], axis=2)\n",
    "    \n",
    "    for x in range(n_layers):\n",
    "        hidden = gru_layer(hidden_dim, dropout)(hidden)\n",
    "    \n",
    "    # Keep only the first `pred_len` time steps of the GRU output and flatten them\n",
    "    # into one feature vector per sample\n",
    "    truncated = hidden[:, :pred_len]\n",
    "    truncated = L.Flatten()(truncated)\n",
    "    \n",
    "    embed_site = L.Embedding(input_dim=site_size, output_dim=1)(input_site)\n",
    "    embed_site = L.Flatten()(embed_site)\n",
    "        \n",
    "    truncated = L.concatenate([truncated, input_time,embed_site], axis=1)\n",
    "    \n",
    "    #out = L.Dense(32, activation='linear')(truncated)\n",
    "\n",
    "    out = L.Dense(2, activation='linear')(truncated)\n",
    "        \n",
    "    model = tf.keras.Model(inputs=[inputs,input_time,input_site], outputs=out)\n",
    "    model.compile(tf.optimizers.Adam(lr), loss='mse')\n",
    "    \n",
    "    return model\n",
    "\n",
    "# get_embed_size(n_cat) is defined in the build_model_time cell and is not redefined here.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# def build_model_time_floors_site(ssid_size,bssid_size,site_size,seq_len=100,dropout=0.5, \n",
    "#                 sp_dropout=0.2, embed_dim=64, hidden_dim=256, n_layers=2):\n",
    "#     inputs = L.Input(shape=(seq_len, 2))\n",
    "#     input_time = L.Input(shape = (2,)) ##time and floor\n",
    "#     input_site = L.Input(shape = (1,)) \n",
    "\n",
    "# #     ssid_fea = inputs[:, :, :1]\n",
    "#     bssid_fea = inputs[:,:,:1]\n",
    "#     rssi_fea = inputs[:,:,1:]\n",
    "\n",
    "# #     embed_ssid = L.Embedding(input_dim=ssid_size, output_dim=32)(ssid_fea)\n",
    "#     embed_bssid = L.Embedding(input_dim=bssid_size, output_dim=64)(bssid_fea)\n",
    "#     embed_site = L.Embedding(input_dim=site_size, output_dim=3)(input_site)\n",
    "\n",
    "# #     embed_ssid = L.Flatten()(embed_ssid)\n",
    "#     embed_bssid = L.Flatten()(embed_bssid)\n",
    "#     embed_site = L.Flatten()(embed_site)\n",
    "#     rssi_fea = L.Flatten()(rssi_fea)\n",
    "\n",
    "#     #reshaped = tf.reshape(embed, shape=(-1, embed.shape[1],  embed.shape[2] * embed.shape[3]))\n",
    "#     #reshaped = L.SpatialDropout1D(sp_dropout)(reshaped)\n",
    "    \n",
    "    \n",
    "#     hidden = L.concatenate([input_time,embed_bssid,rssi_fea], axis=1)\n",
    "#     hidden = L.Dropout(0.2)(hidden)\n",
    "#     print(hidden.shape)\n",
    "#     x = L.Reshape((1, -1))(hidden)\n",
    "    \n",
    "#     x = L.BatchNormalization()(x)\n",
    "#     x = L.LSTM(128, dropout=0.3, recurrent_dropout=0.3, return_sequences=True, activation='relu')(x)\n",
    "#     x = L.LSTM(16, dropout=0.1, return_sequences=False, activation='relu')(x)\n",
    "\n",
    "#     out = L.Dense(2, activation='linear')(x)\n",
    "\n",
    "        \n",
    "#     model = tf.keras.Model(inputs=[inputs,input_time,input_site], outputs=out)\n",
    "#     model.compile(tf.optimizers.Adam(), loss='mse')\n",
    "    \n",
    "#     return model\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle\n",
    "# with open('train_all.pickle','wb') as fw:\n",
    "#     pickle.dump(train_all,fw)\n",
    "# with open('test_all.pickle','wb') as fw:\n",
    "#     pickle.dump(test_all,fw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "begin fold: 0\n",
      "fold 0 7.73294929426513\n",
      "150.92601263675743\n",
      "elasped time: 84.61294651031494\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "t1 = time.time()\n",
    "\n",
    "# Number of CV folds to train: 1 trains fold 0 only (quick run);\n",
    "# set to N_SPLITS for full cross-validation.\n",
    "N_RUN_FOLDS = 1\n",
    "CKPT_DIR = 'rnn_model_wifi'\n",
    "\n",
    "pred_cols = ['x','y']\n",
    "aux_cols = ['timestamp','floorNo','wifi_len','wifi_mean']\n",
    "train_inputs = preprocess_inputs(train_all,cols=['ssid', 'bssid', 'rssi'])\n",
    "train_inputs_time = train_all[aux_cols].values\n",
    "train_inputs_site = train_all['site_id'].values\n",
    "train_labels = train_all[pred_cols].values\n",
    "test_inputs = preprocess_inputs(test_all,cols=['ssid','bssid', 'rssi'])\n",
    "test_inputs_time = test_all[aux_cols].values\n",
    "test_inputs_site = test_all['site_id'].values\n",
    "\n",
    "x_test = test_inputs\n",
    "x_test_time = test_inputs_time\n",
    "x_test_site = test_inputs_site\n",
    "\n",
    "# Make sure the checkpoint directory exists before training.\n",
    "os.makedirs(CKPT_DIR, exist_ok=True)\n",
    "\n",
    "# Fold membership as positions: the numpy arrays above are indexed positionally,\n",
    "# so this does not depend on train_all having a 0..n-1 index.\n",
    "fold_values = train_all['fold'].values\n",
    "oof_xy = np.zeros(train_labels.shape)\n",
    "oof_done = np.zeros(len(train_labels), dtype=bool)  # rows with an OOF prediction\n",
    "y_test_pred = 0\n",
    "n_folds_run = 0\n",
    "for fold_id in range(min(N_RUN_FOLDS, N_SPLITS)):\n",
    "    trn_idx = np.flatnonzero(fold_values != fold_id)\n",
    "    val_idx = np.flatnonzero(fold_values == fold_id)\n",
    "    print('begin fold:',fold_id)\n",
    "    x_train, x_val = train_inputs[trn_idx],train_inputs[val_idx]\n",
    "    x_train_time, x_val_time = train_inputs_time[trn_idx],train_inputs_time[val_idx]\n",
    "    x_train_site, x_val_site = train_inputs_site[trn_idx],train_inputs_site[val_idx]\n",
    "    y_train, y_val = train_labels[trn_idx],train_labels[val_idx]\n",
    "\n",
    "    model = build_model_mix(len(ssiddict),len(bssiddict),len(sitedict),seqlen,lr=0.001)\n",
    "    # NOTE(review): ReduceLROnPlateau and EarlyStopping both use patience=5, so the LR\n",
    "    # reduction tends to fire in the same epoch training stops; consider a smaller LR patience.\n",
    "    history = model.fit(\n",
    "            [x_train,x_train_time,x_train_site], y_train,\n",
    "            validation_data=([x_val,x_val_time,x_val_site], y_val),\n",
    "            batch_size=128,\n",
    "            epochs=100,\n",
    "            verbose=1,\n",
    "            callbacks=[\n",
    "                tf.keras.callbacks.ReduceLROnPlateau(patience=5),\n",
    "                # Persist the best val_loss weights. Without save_best_only, the file on\n",
    "                # disk is the last epoch, not the weights EarlyStopping restores.\n",
    "                tf.keras.callbacks.ModelCheckpoint(\n",
    "                    os.path.join(CKPT_DIR, 'model_fold{}.h5'.format(fold_id)),\n",
    "                    monitor='val_loss', save_best_only=True),\n",
    "                tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-4,\n",
    "                                  patience=5, mode='min', restore_best_weights=True)\n",
    "            ]\n",
    "        )\n",
    "    y_val_pred = model.predict([x_val,x_val_time,x_val_site])\n",
    "    y_test_pred += model.predict([x_test,x_test_time,x_test_site])\n",
    "    oof_xy[val_idx] = y_val_pred\n",
    "    oof_done[val_idx] = True\n",
    "    n_folds_run += 1\n",
    "    print('fold',fold_id, np.mean(np.sqrt(np.sum((y_val-y_val_pred)**2,axis=1))))\n",
    "\n",
    "y_test_pred = y_test_pred / n_folds_run\n",
    "train_labels_inv = pd.DataFrame(train_labels, columns=['x','y'])\n",
    "oof_xy_pred_inv = pd.DataFrame(oof_xy, columns=['x','y'])\n",
    "y_test_pred_inv = pd.DataFrame(y_test_pred, columns=['x','y'])\n",
    "# Score only rows that received an out-of-fold prediction. Rows of untrained\n",
    "# folds still hold zeros and would inflate the mean error.\n",
    "oof_err = np.sqrt(np.sum((train_labels[oof_done] - oof_xy[oof_done])**2, axis=1))\n",
    "print(np.mean(oof_err))\n",
    "\n",
    "t2 = time.time()\n",
    "print('elasped time:', t2 - t1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign positionally (.values). Assigning a DataFrame aligns on the index and would\n",
    "# silently misplace or NaN-fill predictions if test_all's index is not 0..n-1.\n",
    "test_all['x'] = y_test_pred_inv['x'].values\n",
    "test_all['y'] = y_test_pred_inv['y'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>ssid</th>\n",
       "      <th>bssid</th>\n",
       "      <th>rssi</th>\n",
       "      <th>path</th>\n",
       "      <th>floorNo</th>\n",
       "      <th>wifi_len</th>\n",
       "      <th>wifi_mean</th>\n",
       "      <th>wifi_median</th>\n",
       "      <th>wifi_std</th>\n",
       "      <th>site</th>\n",
       "      <th>site_id</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.345764</td>\n",
       "      <td>[7007, 9522, 15215, 18669, 15215, 19396, 4851,...</td>\n",
       "      <td>[35106, 10783, 39335, 4531, 48757, 19211, 1176...</td>\n",
       "      <td>[1.9251305288464635, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.038</td>\n",
       "      <td>0.024464</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>1.033093</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>19</td>\n",
       "      <td>49.430897</td>\n",
       "      <td>89.246811</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.345765</td>\n",
       "      <td>[18669, 9522, 7007, 19396, 15215, 15215, 1264,...</td>\n",
       "      <td>[4531, 10783, 35106, 19211, 39335, 48757, 6030...</td>\n",
       "      <td>[2.1219297495845346, 1.4331324770012857, 1.334...</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>0.845957</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.075218</td>\n",
       "      <td>-0.338061</td>\n",
       "      <td>0.991529</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>19</td>\n",
       "      <td>71.179886</td>\n",
       "      <td>87.176270</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                                               ssid  \\\n",
       "0   0.345764  [7007, 9522, 15215, 18669, 15215, 19396, 4851,...   \n",
       "1   0.345765  [18669, 9522, 7007, 19396, 15215, 15215, 1264,...   \n",
       "\n",
       "                                               bssid  \\\n",
       "0  [35106, 10783, 39335, 4531, 48757, 19211, 1176...   \n",
       "1  [4531, 10783, 35106, 19211, 39335, 48757, 6030...   \n",
       "\n",
       "                                                rssi  \\\n",
       "0  [1.9251305288464635, 1.4331324770012857, 1.334...   \n",
       "1  [2.1219297495845346, 1.4331324770012857, 1.334...   \n",
       "\n",
       "                       path   floorNo  wifi_len  wifi_mean  wifi_median  \\\n",
       "0  00ff0c9a71cc37a2ebdd0f05  0.845957     0.038   0.024464    -0.338061   \n",
       "1  00ff0c9a71cc37a2ebdd0f05  0.845957     0.040   0.075218    -0.338061   \n",
       "\n",
       "   wifi_std                      site  site_id          x          y  \n",
       "0  1.033093  5da1389e4db8ce0c98bd0547       19  49.430897  89.246811  \n",
       "1  0.991529  5da1389e4db8ce0c98bd0547       19  71.179886  87.176270  "
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the test rows with the predicted x/y columns attached\n",
    "test_all.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>site</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>t1_wifi</th>\n",
       "      <th>path_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.345764</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>49.430897</td>\n",
       "      <td>89.246811</td>\n",
       "      <td>1180.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.345765</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.179886</td>\n",
       "      <td>87.176270</td>\n",
       "      <td>3048.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.408737</td>\n",
       "      <td>86.979248</td>\n",
       "      <td>4924.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.819069</td>\n",
       "      <td>83.849525</td>\n",
       "      <td>6816.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.345767</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.560272</td>\n",
       "      <td>86.284660</td>\n",
       "      <td>8693.0</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   timestamp                      path                      site          x  \\\n",
       "0   0.345764  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  49.430897   \n",
       "1   0.345765  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  71.179886   \n",
       "2   0.345766  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  71.408737   \n",
       "3   0.345766  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  71.819069   \n",
       "4   0.345767  00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547  71.560272   \n",
       "\n",
       "           y  t1_wifi                                            path_id  \n",
       "0  89.246811   1180.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "1  87.176270   3048.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "2  86.979248   4924.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "3  83.849525   6816.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  \n",
       "4  86.284660   8693.0  5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  "
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Undo the timestamp scaling, then subtract the per-path offset stored in testpath2gap.\n",
    "# .copy() makes result an independent frame, so the column assignments below can never\n",
    "# write through to test_all (chained-assignment warnings are silenced in this notebook).\n",
    "result = test_all[['timestamp','path','site','x','y']].copy()\n",
    "# scikit-learn >= 1.0 scalers reject 1-D input: pass the column as 2-D and flatten back.\n",
    "# NOTE(review): assumes ss2 was fitted on this single column (the old 1-D call only\n",
    "# broadcast correctly in that case).\n",
    "result['t1_wifi'] = ss2.inverse_transform(result[['timestamp']].to_numpy()).ravel()\n",
    "\n",
    "result['t1_wifi'] = [xx-testpath2gap[yy][0] for (xx,yy) in zip(result['t1_wifi'],result['path'])]\n",
    "result['path_id'] = result['site']+'_'+result['path']\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "      <th>path</th>\n",
       "      <th>site</th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>t1_wifi</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>path_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345764</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>49.430897</td>\n",
       "      <td>89.246811</td>\n",
       "      <td>1180.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345765</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.179886</td>\n",
       "      <td>87.176270</td>\n",
       "      <td>3048.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.408737</td>\n",
       "      <td>86.979248</td>\n",
       "      <td>4924.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345766</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.819069</td>\n",
       "      <td>83.849525</td>\n",
       "      <td>6816.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05</th>\n",
       "      <td>0.345767</td>\n",
       "      <td>00ff0c9a71cc37a2ebdd0f05</td>\n",
       "      <td>5da1389e4db8ce0c98bd0547</td>\n",
       "      <td>71.560272</td>\n",
       "      <td>86.284660</td>\n",
       "      <td>8693.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                   timestamp  \\\n",
       "path_id                                                        \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345764   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345765   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345766   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345766   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   0.345767   \n",
       "\n",
       "                                                                       path  \\\n",
       "path_id                                                                       \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  00ff0c9a71cc37a2ebdd0f05   \n",
       "\n",
       "                                                                       site  \\\n",
       "path_id                                                                       \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  5da1389e4db8ce0c98bd0547   \n",
       "\n",
       "                                                           x          y  \\\n",
       "path_id                                                                   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  49.430897  89.246811   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  71.179886  87.176270   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  71.408737  86.979248   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  71.819069  83.849525   \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05  71.560272  86.284660   \n",
       "\n",
       "                                                   t1_wifi  \n",
       "path_id                                                     \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   1180.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   3048.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   4924.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   6816.0  \n",
       "5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0f05   8693.0  "
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index predictions by '<site>_<path>' so a whole trajectory is a single result.loc[...] lookup.\n",
    "result = result.set_index('path_id')\n",
    "result.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.transform import Rotation as R\n",
    "from PIL import Image\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import plotly.graph_objs as go\n",
    "from pathlib import Path\n",
    "import scipy.signal as signal\n",
    "import json\n",
    "import seaborn as sns # visualization\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import matplotlib.pyplot as plt  # visualization\n",
    "import numpy as np  # linear algebra\n",
    "import random\n",
    "import pandas as pd\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "\n",
    "def split_ts_seq(ts_seq, sep_ts):\n",
    "    \"\"\"Cut a time-ordered sequence into consecutive chunks at separator timestamps.\n",
    "\n",
    "    :param ts_seq: 2-D array whose column 0 holds timestamps; rows are expected to be\n",
    "        sorted by that column (it is searched with np.searchsorted).\n",
    "    :param sep_ts: separator timestamps; they are de-duplicated and sorted first.\n",
    "    :return: list of row blocks (copies). A block ends at the last row whose timestamp is\n",
    "        <= a separator; separators adding no rows are skipped, and rows after the last\n",
    "        separator form a final block.\n",
    "    \"\"\"\n",
    "    timestamps = ts_seq[:, 0].astype(float)\n",
    "    chunks = []\n",
    "    lo = 0\n",
    "    for boundary in np.unique(sep_ts):\n",
    "        hi = np.searchsorted(timestamps, boundary, side='right')\n",
    "        # a separator that adds no new rows would give an empty block: skip it\n",
    "        if hi != lo:\n",
    "            chunks.append(ts_seq[lo:hi, :].copy())\n",
    "            lo = hi\n",
    "\n",
    "    # whatever is left after the final separator\n",
    "    if lo < ts_seq.shape[0]:\n",
    "        chunks.append(ts_seq[lo:, :].copy())\n",
    "\n",
    "    return chunks\n",
    "\n",
    "\n",
    "def correct_trajectory(original_xys, end_xy):\n",
    "    \"\"\"Warp a trajectory so it keeps its first point and ends exactly at end_xy.\n",
    "\n",
    "    Every point is rotated about the first point A and scaled by |AB| / |AB'|, where B'\n",
    "    is the trajectory's last point and B is end_xy, which maps B' onto B.\n",
    "\n",
    "    :param original_xys: array-like, shape (N, 2); row 0 is the fixed start point A.\n",
    "    :param end_xy: array-like holding the 2 coordinates of the target end point B\n",
    "        (shape (2,) or (1, 2)).\n",
    "    :return: float ndarray, shape (N, 2). Row 0 equals A and the last row equals B (up to\n",
    "        rounding). If B' coincides with A (closed loop or a single point) no rotation or\n",
    "        scale exists; the offset B - B' is then spread linearly along the trajectory.\n",
    "    \"\"\"\n",
    "    original_xys = np.asarray(original_xys, dtype=float)\n",
    "    B = np.asarray(end_xy, dtype=float).reshape(-1)\n",
    "    A = original_xys[0, :]\n",
    "    Bp = original_xys[-1, :]\n",
    "\n",
    "    AB = np.sqrt(np.sum((B - A) ** 2))\n",
    "    ABp = np.sqrt(np.sum((Bp - A) ** 2))\n",
    "\n",
    "    if ABp == 0:\n",
    "        # degenerate case: scaling by AB / ABp would divide by zero (NaN/inf positions)\n",
    "        weights = np.linspace(0.0, 1.0, original_xys.shape[0])[:, None]\n",
    "        return original_xys + weights * (B - Bp)\n",
    "\n",
    "    # rotation that turns direction AB' into direction AB\n",
    "    angle_BAX = np.arctan2(B[1] - A[1], B[0] - A[0])\n",
    "    angle_BpAX = np.arctan2(Bp[1] - A[1], Bp[0] - A[0])\n",
    "    angle_BpAB = angle_BpAX - angle_BAX\n",
    "\n",
    "    # vectorised over all points instead of growing the output row by row (O(N^2))\n",
    "    deltas = original_xys - A\n",
    "    angle_CAX = np.arctan2(deltas[:, 1], deltas[:, 0]) - angle_BpAB\n",
    "    AC = np.sqrt(np.sum(deltas ** 2, axis=1)) * AB / ABp\n",
    "    corrected_xys = A + np.column_stack((AC * np.cos(angle_CAX), AC * np.sin(angle_CAX)))\n",
    "    corrected_xys[0, :] = A\n",
    "\n",
    "    return corrected_xys\n",
    "\n",
    "\n",
    "def correct_positions(rel_positions, reference_positions):\n",
    "    \"\"\"Turn relative step displacements into absolute positions anchored on reference fixes.\n",
    "\n",
    "    Steps are bucketed into the intervals (t_i, t_{i+1}] between consecutive reference\n",
    "    timestamps. Within each interval the displacements are accumulated from reference\n",
    "    position i, and the resulting path is warped by correct_trajectory so that it ends on\n",
    "    reference position i+1.\n",
    "\n",
    "    :param rel_positions: array, shape (S, 3), rows [timestamp, dx, dy], sorted by time.\n",
    "    :param reference_positions: array, shape (M, 3), rows [timestamp, x, y], sorted by time.\n",
    "    :return: ndarray, shape (K, 3), rows [timestamp, x, y]. The first row is reference\n",
    "        position 0, followed by the corrected step positions of each interval in order.\n",
    "        Steps before the first or after the last reference timestamp are dropped, and an\n",
    "        interval without steps contributes no rows. With fewer than two reference\n",
    "        positions an empty (0, 3) array is returned.\n",
    "    \"\"\"\n",
    "    rel_positions = np.asarray(rel_positions, dtype=float)\n",
    "    reference_positions = np.asarray(reference_positions, dtype=float)\n",
    "    n_refs = reference_positions.shape[0]\n",
    "    if n_refs < 2:\n",
    "        return np.zeros((0, 3))\n",
    "\n",
    "    # bounds[i] = number of steps with timestamp <= t_i, so the steps of interval\n",
    "    # (t_i, t_{i+1}] are rel_positions[bounds[i]:bounds[i + 1]]. Empty intervals keep their\n",
    "    # slot here; chunking with split_ts_seq drops them (and treats steps before t_0 as\n",
    "    # interval 0), which shifts later steps onto the wrong pair of reference points.\n",
    "    bounds = np.searchsorted(rel_positions[:, 0], reference_positions[:, 0], side='right')\n",
    "\n",
    "    pieces = []\n",
    "    for i in range(n_refs - 1):\n",
    "        start_position = reference_positions[i]\n",
    "        end_position = reference_positions[i + 1]\n",
    "        rel_ps = rel_positions[bounds[i]:bounds[i + 1]]\n",
    "        # running sum seeded with the start point: start, start+d0, start+d0+d1, ...\n",
    "        abs_xys = np.cumsum(np.vstack((start_position[1:3], rel_ps[:, 1:3])), axis=0)\n",
    "        corrected_xys = correct_trajectory(abs_xys, end_position[1:3])\n",
    "        timestamps = np.concatenate(([start_position[0]], rel_ps[:, 0]))\n",
    "        corrected_ps = np.column_stack((timestamps, corrected_xys))\n",
    "        # only interval 0 emits its reference start row; later intervals begin where the\n",
    "        # previous interval was warped to end\n",
    "        pieces.append(corrected_ps if i == 0 else corrected_ps[1:])\n",
    "\n",
    "    return np.concatenate(pieces, axis=0)\n",
    "\n",
    "\n",
    "def init_parameters_filter(sample_freq, warmup_data, cut_off_freq=2, order=4):\n",
    "    \"\"\"Design a low-pass Butterworth filter and pre-load its state with warm-up samples.\n",
    "\n",
    "    :param sample_freq: sampling frequency of the signal (same unit as cut_off_freq).\n",
    "    :param warmup_data: 1-D samples run through the filter to settle its internal state.\n",
    "    :param cut_off_freq: low-pass cut-off frequency; must be below sample_freq / 2.\n",
    "    :param order: Butterworth filter order (default 4).\n",
    "    :return: (b, a, zf): numerator and denominator coefficients, and the filter state to\n",
    "        pass as zi to the next signal.lfilter call.\n",
    "    \"\"\"\n",
    "    nyquist = sample_freq / 2\n",
    "    filter_b, filter_a = signal.butter(order, cut_off_freq / nyquist, 'low', False)\n",
    "    zf = signal.lfilter_zi(filter_b, filter_a)\n",
    "    # lfilter_zi is the steady state for a unit-level input, not for warmup_data's level,\n",
    "    # so the warm-up block is run twice, each pass continuing from the previous state\n",
    "    # (presumably to let the state settle at the warm-up level)\n",
    "    _, zf = signal.lfilter(filter_b, filter_a, warmup_data, zi=zf)\n",
    "    _, filter_zf = signal.lfilter(filter_b, filter_a, warmup_data, zi=zf)\n",
    "\n",
    "    return filter_b, filter_a, filter_zf\n",
    "\n",
    "\n",
    "def get_rotation_matrix_from_vector(rotation_vector, matrix_size=9):\n",
    "    \"\"\"Build a rotation matrix from a quaternion-style rotation vector.\n",
    "\n",
    "    The vector is (q1, q2, q3[, q0]) with q0 the scalar part. When q0 is missing it is\n",
    "    reconstructed as sqrt(1 - q1^2 - q2^2 - q3^2), clamped to 0.\n",
    "\n",
    "    :param rotation_vector: 1-D ndarray with 3 or more elements.\n",
    "    :param matrix_size: 9 for a 3x3 matrix (default) or 16 for a 4x4 homogeneous matrix.\n",
    "    :return: ndarray of shape (3, 3) or (4, 4).\n",
    "    :raises ValueError: if matrix_size is neither 9 nor 16.\n",
    "    \"\"\"\n",
    "    # the 4x4 layout used to be unreachable: the buffer was always allocated with 9 slots\n",
    "    if matrix_size not in (9, 16):\n",
    "        raise ValueError(f'matrix_size must be 9 or 16, got {matrix_size}')\n",
    "\n",
    "    q1 = rotation_vector[0]\n",
    "    q2 = rotation_vector[1]\n",
    "    q3 = rotation_vector[2]\n",
    "\n",
    "    if rotation_vector.size >= 4:\n",
    "        q0 = rotation_vector[3]\n",
    "    else:\n",
    "        # scalar part implied by unit norm; clamp negative values caused by rounding\n",
    "        q0 = 1 - q1*q1 - q2*q2 - q3*q3\n",
    "        q0 = np.sqrt(q0) if q0 > 0 else 0\n",
    "\n",
    "    sq_q1 = 2 * q1 * q1\n",
    "    sq_q2 = 2 * q2 * q2\n",
    "    sq_q3 = 2 * q3 * q3\n",
    "    q1_q2 = 2 * q1 * q2\n",
    "    q3_q0 = 2 * q3 * q0\n",
    "    q1_q3 = 2 * q1 * q3\n",
    "    q2_q0 = 2 * q2 * q0\n",
    "    q2_q3 = 2 * q2 * q3\n",
    "    q1_q0 = 2 * q1 * q0\n",
    "\n",
    "    rot3 = np.array([\n",
    "        [1 - sq_q2 - sq_q3, q1_q2 - q3_q0, q1_q3 + q2_q0],\n",
    "        [q1_q2 + q3_q0, 1 - sq_q1 - sq_q3, q2_q3 - q1_q0],\n",
    "        [q1_q3 - q2_q0, q2_q3 + q1_q0, 1 - sq_q1 - sq_q2],\n",
    "    ], dtype=float)\n",
    "    if matrix_size == 9:\n",
    "        return rot3\n",
    "\n",
    "    # 4x4 homogeneous form: rotation in the upper-left block, zeros elsewhere, 1 in the corner\n",
    "    rot4 = np.eye(4)\n",
    "    rot4[:3, :3] = rot3\n",
    "    return rot4\n",
    "\n",
    "\n",
    "def get_orientation(R):\n",
    "    \"\"\"Extract [azimuth, pitch, roll] in radians from a rotation matrix.\n",
    "\n",
    "    :param R: 3x3 matrix (9 elements) or, for any other size, a 4x4 matrix; read row-major.\n",
    "    :return: ndarray [azimuth, pitch, roll].\n",
    "    \"\"\"\n",
    "    flat_R = np.asarray(R).flatten()\n",
    "    # a 4x4 matrix has one extra column, so the same entries sit at different flat indices\n",
    "    if np.size(flat_R) == 9:\n",
    "        m01, m11, m20, m21, m22 = flat_R[[1, 4, 6, 7, 8]]\n",
    "    else:\n",
    "        m01, m11, m20, m21, m22 = flat_R[[1, 5, 8, 9, 10]]\n",
    "\n",
    "    # rounding can push |m21| slightly above 1, where arcsin would return NaN\n",
    "    pitch = np.arcsin(np.clip(-m21, -1.0, 1.0))\n",
    "    values = np.array([np.arctan2(m01, m11), pitch, np.arctan2(-m20, m22)])\n",
    "\n",
    "    return values\n",
    "\n",
    "\n",
    "def compute_steps(acce_datas):\n",
    "    \"\"\"Detect walking steps from accelerometer samples with a peak/valley state machine.\n",
    "\n",
    "    The acceleration magnitude is low-pass filtered sample by sample, detrended by its\n",
    "    sliding-window mean, and labelled peak (1), valley (-1) or neither (0). With the active\n",
    "    criterion (step_criterion == 1), a step is accepted when a valley ends while a peak is\n",
    "    the last recorded extremum and more than interval_threshold after the previously\n",
    "    recorded valley.\n",
    "\n",
    "    :param acce_datas: 2-D array, one row per sample. Column 0 is the timestamp; the\n",
    "        remaining columns are acceleration components (only their Euclidean norm is used).\n",
    "    :return: (step_timestamps, step_indexs, step_acce_max_mins)\n",
    "        step_timestamps: timestamp of the sample at which each step was accepted.\n",
    "        step_indexs: row index into acce_datas of that sample.\n",
    "        step_acce_max_mins: shape (n_steps, 4), rows [timestamp, max, min, variance]:\n",
    "            the last recorded peak and valley values of the filtered magnitude, and the\n",
    "            variance of the filtered magnitude over the current window.\n",
    "    \"\"\"\n",
    "    step_timestamps = np.array([])\n",
    "    step_indexs = np.array([], dtype=int)\n",
    "    step_acce_max_mins = np.zeros((0, 4))\n",
    "    # sampling rate assumed when designing the low-pass filter (not measured from acce_datas)\n",
    "    sample_freq = 50\n",
    "    # number of recent samples used for the window mean/std and the label history\n",
    "    window_size = 22\n",
    "    # floor of the adaptive peak/valley detection threshold\n",
    "    low_acce_mag = 0.6\n",
    "    # only criterion 1 (the final else branch below) is active; 2 and 3 are alternatives\n",
    "    step_criterion = 1\n",
    "    # timestamp gap separating distinct extrema (presumably milliseconds)\n",
    "    interval_threshold = 250\n",
    "\n",
    "    # [timestamp, value] of the last recorded peak / valley\n",
    "    acce_max = np.zeros((2,))\n",
    "    acce_min = np.zeros((2,))\n",
    "    # label history: 1 peak candidate, -1 valley candidate, 0 neither\n",
    "    acce_binarys = np.zeros((window_size,), dtype=int)\n",
    "    acce_mag_pre = 0\n",
    "    # 0: no peak yet, 1: a peak was recorded last, 2: a valley (step) was recorded last\n",
    "    state_flag = 0\n",
    "\n",
    "    # pre-load the filter state with a constant 9.81 signal (gravity, presumably m/s^2)\n",
    "    warmup_data = np.ones((window_size,)) * 9.81\n",
    "    filter_b, filter_a, filter_zf = init_parameters_filter(sample_freq, warmup_data)\n",
    "    # starts as (window_size, 1); the first np.append below flattens it to 1-D\n",
    "    acce_mag_window = np.zeros((window_size, 1))\n",
    "\n",
    "    # detect steps according to acceleration magnitudes\n",
    "    for i in np.arange(0, np.size(acce_datas, 0)):\n",
    "        acce_data = acce_datas[i, :]\n",
    "        acce_mag = np.sqrt(np.sum(acce_data[1:] ** 2))\n",
    "\n",
    "        acce_mag_filt, filter_zf = signal.lfilter(filter_b, filter_a, [acce_mag], zi=filter_zf)\n",
    "        acce_mag_filt = acce_mag_filt[0]\n",
    "\n",
    "        acce_mag_window = np.append(acce_mag_window, [acce_mag_filt])\n",
    "        acce_mag_window = np.delete(acce_mag_window, 0)\n",
    "        # window mean of the filtered magnitude, used as the detrending baseline\n",
    "        mean_gravity = np.mean(acce_mag_window)\n",
    "        acce_std = np.std(acce_mag_window)\n",
    "        # adaptive threshold: 0.4 * window std, but never below low_acce_mag\n",
    "        mag_threshold = np.max([low_acce_mag, 0.4 * acce_std])\n",
    "\n",
    "        # detect valid peak or valley of acceleration magnitudes\n",
    "        acce_mag_filt_detrend = acce_mag_filt - mean_gravity\n",
    "        if acce_mag_filt_detrend > np.max([acce_mag_pre, mag_threshold]):\n",
    "            # peak\n",
    "            acce_binarys = np.append(acce_binarys, [1])\n",
    "            acce_binarys = np.delete(acce_binarys, 0)\n",
    "        elif acce_mag_filt_detrend < np.min([acce_mag_pre, -mag_threshold]):\n",
    "            # valley\n",
    "            acce_binarys = np.append(acce_binarys, [-1])\n",
    "            acce_binarys = np.delete(acce_binarys, 0)\n",
    "        else:\n",
    "            # between peak and valley\n",
    "            acce_binarys = np.append(acce_binarys, [0])\n",
    "            acce_binarys = np.delete(acce_binarys, 0)\n",
    "\n",
    "        # a peak just ended (label 1 -> 0): record the first peak, raise a nearby peak,\n",
    "        # or start a new peak once far enough past the last recorded one after a valley\n",
    "        if (acce_binarys[-1] == 0) and (acce_binarys[-2] == 1):\n",
    "            if state_flag == 0:\n",
    "                acce_max[:] = acce_data[0], acce_mag_filt\n",
    "                state_flag = 1\n",
    "            elif (state_flag == 1) and ((acce_data[0] - acce_max[0]) <= interval_threshold) and (\n",
    "                    acce_mag_filt > acce_max[1]):\n",
    "                acce_max[:] = acce_data[0], acce_mag_filt\n",
    "            elif (state_flag == 2) and ((acce_data[0] - acce_max[0]) > interval_threshold):\n",
    "                acce_max[:] = acce_data[0], acce_mag_filt\n",
    "                state_flag = 1\n",
    "\n",
    "        # choose reasonable step criterion and check if there is a valid step\n",
    "        # save step acceleration data: step_acce_max_mins = [timestamp, max, min, variance]\n",
    "        step_flag = False\n",
    "        if step_criterion == 2:\n",
    "            if (acce_binarys[-1] == -1) and ((acce_binarys[-2] == 1) or (acce_binarys[-2] == 0)):\n",
    "                step_flag = True\n",
    "        elif step_criterion == 3:\n",
    "            if (acce_binarys[-1] == -1) and (acce_binarys[-2] == 0) and (np.sum(acce_binarys[:-2]) > 1):\n",
    "                step_flag = True\n",
    "        else:\n",
    "            # a valley just ended (label -1 -> 0): accept a step after a peak, or lower a\n",
    "            # nearby valley that follows an already accepted step\n",
    "            if (acce_binarys[-1] == 0) and acce_binarys[-2] == -1:\n",
    "                if (state_flag == 1) and ((acce_data[0] - acce_min[0]) > interval_threshold):\n",
    "                    acce_min[:] = acce_data[0], acce_mag_filt\n",
    "                    state_flag = 2\n",
    "                    step_flag = True\n",
    "                elif (state_flag == 2) and ((acce_data[0] - acce_min[0]) <= interval_threshold) and (\n",
    "                        acce_mag_filt < acce_min[1]):\n",
    "                    acce_min[:] = acce_data[0], acce_mag_filt\n",
    "        if step_flag:\n",
    "            step_timestamps = np.append(step_timestamps, acce_data[0])\n",
    "            step_indexs = np.append(step_indexs, [i])\n",
    "            step_acce_max_mins = np.append(step_acce_max_mins,\n",
    "                                           [[acce_data[0], acce_max[1], acce_min[1], acce_std ** 2]], axis=0)\n",
    "        # remember this sample's detrended value for the next peak/valley comparison\n",
    "        acce_mag_pre = acce_mag_filt_detrend\n",
    "\n",
    "    return step_timestamps, step_indexs, step_acce_max_mins\n",
    "\n",
    "\n",
    "def compute_stride_length(step_acce_max_mins):\n",
    "    \"\"\"Estimate the length of every detected step.\n",
    "\n",
    "    For each step after the first, a coefficient k is computed from the (moving-average)\n",
    "    step period and the acceleration variance, then clamped to [K_min, K_max]. The first\n",
    "    step uses the fixed coefficient K. The stride is k * max(max - min, 1) ** (1/4).\n",
    "\n",
    "    :param step_acce_max_mins: array, shape (n_steps, 4), rows [timestamp, max, min,\n",
    "        variance] as produced by compute_steps. Timestamp differences are divided by 1000\n",
    "        to get seconds, so timestamps are presumably milliseconds.\n",
    "    :return: array, shape (n_steps, 2), rows [timestamp, stride_length]; an empty (0, 2)\n",
    "        array when there are no steps.\n",
    "    \"\"\"\n",
    "    K = 0.4\n",
    "    K_max = 0.8\n",
    "    K_min = 0.4\n",
    "    para_a0 = 0.21468084\n",
    "    para_a1 = 0.09154517\n",
    "    para_a2 = 0.02301998\n",
    "\n",
    "    n_steps = step_acce_max_mins.shape[0]\n",
    "    if n_steps == 0:\n",
    "        # without this guard np.zeros((n_steps - 1,)) raises on a negative dimension\n",
    "        return np.zeros((0, 2))\n",
    "\n",
    "    stride_lengths = np.zeros((n_steps, 2))\n",
    "    k_real = np.zeros((n_steps, 2))\n",
    "    step_timeperiod = np.zeros((n_steps - 1, ))\n",
    "    stride_lengths[:, 0] = step_acce_max_mins[:, 0]\n",
    "    window_size = 2\n",
    "    step_timeperiod_temp = np.zeros((0, ))\n",
    "\n",
    "    # calculate every step period - step_timeperiod unit: second\n",
    "    # (each entry is the mean of the last window_size periods)\n",
    "    for i in range(0, step_timeperiod.shape[0]):\n",
    "        step_timeperiod_data = (step_acce_max_mins[i + 1, 0] - step_acce_max_mins[i, 0]) / 1000\n",
    "        step_timeperiod_temp = np.append(step_timeperiod_temp, [step_timeperiod_data])\n",
    "        if step_timeperiod_temp.shape[0] > window_size:\n",
    "            step_timeperiod_temp = np.delete(step_timeperiod_temp, [0])\n",
    "        step_timeperiod[i] = np.sum(step_timeperiod_temp) / step_timeperiod_temp.shape[0]\n",
    "\n",
    "    # calculate parameters by step period and acceleration magnitude variance\n",
    "    # (step i + 1 uses the averaged period ending at it and the variance stored for step i)\n",
    "    k_real[:, 0] = step_acce_max_mins[:, 0]\n",
    "    k_real[0, 1] = K\n",
    "    for i in range(0, step_timeperiod.shape[0]):\n",
    "        k_real[i + 1, 1] = np.max([(para_a0 + para_a1 / step_timeperiod[i] + para_a2 * step_acce_max_mins[i, 3]), K_min])\n",
    "        k_real[i + 1, 1] = np.min([k_real[i + 1, 1], K_max]) * (K / K_min)\n",
    "\n",
    "    # calculate every stride length by parameters and max and min data of acceleration magnitude\n",
    "    stride_lengths[:, 1] = np.max([(step_acce_max_mins[:, 1] - step_acce_max_mins[:, 2]),\n",
    "                                   np.ones((n_steps, ))], axis=0)**(1 / 4) * k_real[:, 1]\n",
    "\n",
    "    return stride_lengths\n",
    "\n",
    "\n",
    "def compute_headings(ahrs_datas):\n",
    "    \"\"\"Convert rotation-vector samples into heading angles.\n",
    "\n",
    "    :param ahrs_datas: 2-D array, one row per sample: [timestamp, rotation vector...].\n",
    "    :return: array, shape (n_samples, 2), rows [timestamp, heading], where heading is the\n",
    "        negated azimuth from get_orientation taken modulo 2*pi.\n",
    "    \"\"\"\n",
    "    n_samples = np.size(ahrs_datas, 0)\n",
    "    headings = np.zeros((n_samples, 2))\n",
    "    for row_idx in range(n_samples):\n",
    "        sample = ahrs_datas[row_idx, :]\n",
    "        azimuth = get_orientation(get_rotation_matrix_from_vector(sample[1:]))[0]\n",
    "        headings[row_idx, 0] = sample[0]\n",
    "        headings[row_idx, 1] = (-azimuth) % (2 * np.pi)\n",
    "    return headings\n",
    "\n",
    "\n",
    "def compute_step_heading(step_timestamps, headings):\n",
    "    \"\"\"Pick, for every step timestamp, the heading row carrying exactly that timestamp.\n",
    "\n",
    "    Both sequences are walked once in order, so the step timestamps must occur in\n",
    "    headings in the same order; matching uses exact equality.\n",
    "\n",
    "    :param step_timestamps: 1-D sequence of step timestamps.\n",
    "    :param headings: 2-D array of [timestamp, heading] rows.\n",
    "    :return: array, shape (len(step_timestamps), 2) of the matched heading rows.\n",
    "    :raises AssertionError: if some step timestamp is not found.\n",
    "    \"\"\"\n",
    "    n_steps = len(step_timestamps)\n",
    "    step_headings = np.zeros((n_steps, 2))\n",
    "    matched = 0\n",
    "    for heading_row in headings:\n",
    "        if matched >= n_steps:\n",
    "            break\n",
    "        if heading_row[0] == step_timestamps[matched]:\n",
    "            step_headings[matched, :] = heading_row\n",
    "            matched += 1\n",
    "    assert matched == n_steps\n",
    "\n",
    "    return step_headings\n",
    "\n",
    "\n",
    "def compute_rel_positions(stride_lengths, step_headings):\n",
    "    \"\"\"Turn (stride length, heading) pairs into per-step displacement vectors.\n",
    "\n",
    "    :param stride_lengths: array (n, 2) of [timestamp, length] rows.\n",
    "    :param step_headings: array with at least n rows; column 1 is the heading in radians.\n",
    "    :return: array (n, 3) of [timestamp, dx, dy] rows with dx = -length * sin(heading)\n",
    "        and dy = length * cos(heading).\n",
    "    \"\"\"\n",
    "    n_steps = stride_lengths.shape[0]\n",
    "    rel_positions = np.zeros((n_steps, 3))\n",
    "    for idx in range(n_steps):\n",
    "        length = stride_lengths[idx, 1]\n",
    "        theta = step_headings[idx, 1]\n",
    "        rel_positions[idx, 0] = stride_lengths[idx, 0]\n",
    "        rel_positions[idx, 1] = -length * np.sin(theta)\n",
    "        rel_positions[idx, 2] = length * np.cos(theta)\n",
    "\n",
    "    return rel_positions\n",
    "\n",
    "\n",
    "def compute_step_positions(acce_datas, ahrs_datas, posi_datas):\n",
    "    \"\"\"Dead-reckoning pipeline: detect steps, attach headings and stride lengths, then\n",
    "    anchor the resulting displacements on the reference positions.\n",
    "\n",
    "    :param acce_datas: accelerometer rows [timestamp, components...] for compute_steps.\n",
    "    :param ahrs_datas: rotation-vector rows [timestamp, ...] for compute_headings.\n",
    "    :param posi_datas: reference rows [timestamp, x, y] for correct_positions.\n",
    "    :return: the [timestamp, x, y] rows produced by correct_positions.\n",
    "    \"\"\"\n",
    "    step_timestamps, _step_indexs, step_acce_max_mins = compute_steps(acce_datas)\n",
    "    headings = compute_headings(ahrs_datas)\n",
    "    strides = compute_stride_length(step_acce_max_mins)\n",
    "    rel_positions = compute_rel_positions(strides, compute_step_heading(step_timestamps, headings))\n",
    "\n",
    "    return correct_positions(rel_positions, posi_datas)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The official template lists every site_path_timestamp key that must be predicted.\n",
    "sample_submission = pd.read_csv(Path('../input/indoor-location-navigation') / 'sample_submission.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>building</th>\n",
       "      <th>path_id</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">5a0546857ecc773753327266</th>\n",
       "      <th>046cfa46be49fc10834815c6</th>\n",
       "      <td>[0000000000009, 0000000009017, 0000000015326, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>05d052dde78384b0c543d89c</th>\n",
       "      <td>[0000000000012, 0000000005748, 0000000014654, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0c06cc9f21d172618d74c6c8</th>\n",
       "      <td>[0000000000011, 0000000011818, 0000000019825, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146035943a1482883ed98570</th>\n",
       "      <td>[0000000000011, 0000000004535, 0000000011498, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1ef2771dfea25d508142ba06</th>\n",
       "      <td>[0000000000009, 0000000012833, 0000000021759, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                           timestamp\n",
       "building                 path_id                                                                    \n",
       "5a0546857ecc773753327266 046cfa46be49fc10834815c6  [0000000000009, 0000000009017, 0000000015326, ...\n",
       "                         05d052dde78384b0c543d89c  [0000000000012, 0000000005748, 0000000014654, ...\n",
       "                         0c06cc9f21d172618d74c6c8  [0000000000011, 0000000011818, 0000000019825, ...\n",
       "                         146035943a1482883ed98570  [0000000000011, 0000000004535, 0000000011498, ...\n",
       "                         1ef2771dfea25d508142ba06  [0000000000009, 0000000012833, 0000000021759, ..."
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# site_path_timestamp is '<site>_<path>_<timestamp>'; keep each piece as a string column\n",
    "# (timestamps stay zero-padded strings at this point).\n",
    "for position, column in enumerate(['building', 'path_id', 'timestamp']):\n",
    "    sample_submission[column] = [key.split('_')[position] for key in sample_submission['site_path_timestamp']]\n",
    "# one row per (building, path_id) holding the list of timestamps to predict on that path\n",
    "samples = sample_submission.groupby(['building','path_id'])['timestamp'].apply(list).to_frame()\n",
    "buildings = np.unique([key[0] for key in samples.index])\n",
    "samples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5a0546857ecc773753327266\n",
      "5c3c44b80379370013e0fd2b\n",
      "5d27075f03f801723c2e360f\n",
      "5d27096c03f801723c31e5e0\n",
      "5d27097f03f801723c320d97\n",
      "5d27099f03f801723c32511d\n",
      "5d2709a003f801723c3251bf\n",
      "5d2709b303f801723c327472\n",
      "5d2709bb03f801723c32852c\n",
      "5d2709c303f801723c3299ee\n",
      "5d2709d403f801723c32bd39\n",
      "5d2709e003f801723c32d896\n",
      "5da138274db8ce0c98bbd3d2\n",
      "5da1382d4db8ce0c98bbe92e\n",
      "5da138314db8ce0c98bbf3a0\n",
      "5da138364db8ce0c98bc00f1\n",
      "5da1383b4db8ce0c98bc11ab\n",
      "5da138754db8ce0c98bca82f\n",
      "5da138764db8ce0c98bcaa46\n",
      "5da1389e4db8ce0c98bd0547\n",
      "5da138b74db8ce0c98bd4774\n",
      "5da958dd46f8266d0737457b\n",
      "5dbc1d84c1eb61796cf7c010\n",
      "5dc8cea7659e181adb076a3f\n"
     ]
    }
   ],
   "source": [
    "from scipy.interpolate import interp1d\n",
    "# Public namespace: scipy.ndimage.filters is deprecated in recent SciPy releases.\n",
    "from scipy.ndimage import uniform_filter1d\n",
    "\n",
    "colacce = ['xyz_time','x_acce','y_acce','z_acce']\n",
    "colahrs = ['xyz_time','x_ahrs','y_ahrs','z_ahrs']\n",
    "\n",
    "# For every test path: smooth the wifi-based (t, x, y) predictions, build a\n",
    "# step-based relative trajectory from the accelerometer/AHRS logs, correct it\n",
    "# against the predictions, then write interpolated x/y for each test timestamp\n",
    "# into sample_submission.\n",
    "for building in buildings:\n",
    "    print(building)\n",
    "    paths = samples.loc[building].index\n",
    "    # Acceleration info:\n",
    "    tfm = pd.read_csv(f'indoor_testing_accel/{building}.txt',index_col=0)\n",
    "    for path_id in paths:\n",
    "        # Original predicted values:\n",
    "        xy = result.loc[building+'_'+path_id]\n",
    "        tfmi = tfm.loc[path_id]\n",
    "        # np.float was removed in NumPy 1.24; np.float64 is the dtype it resolved to.\n",
    "        acce_datas = np.array(tfmi[colacce],dtype=np.float64)\n",
    "        ahrs_datas = np.array(tfmi[colahrs],dtype=np.float64)\n",
    "        posi_datas = np.array(xy[['t1_wifi','x','y']],dtype=np.float64)\n",
    "        # Outlier removal: keep points whose |x| and |y| deviation from a 3-point\n",
    "        # running mean is below 3x the standard deviation of those deviations.\n",
    "        xyout = uniform_filter1d(posi_datas,size=3,axis=0,mode='reflect')\n",
    "        xydiff = np.abs(posi_datas-xyout)\n",
    "        xystd = np.std(xydiff,axis=0)*3\n",
    "        keep = (xydiff[:,1]<xystd[1])&(xydiff[:,2]<xystd[2])\n",
    "        # E.g. with only one or two predictions the deviations are all equal, the\n",
    "        # threshold is 0 and the strict test rejects every point, so posi_datas[0,0]\n",
    "        # below would fail. Skip the filter whenever it would leave nothing.\n",
    "        if keep.any():\n",
    "            posi_datas = posi_datas[keep]\n",
    "        # Step detection:\n",
    "        step_timestamps, step_indexs, step_acce_max_mins = compute_steps(acce_datas)\n",
    "        stride_lengths = compute_stride_length(step_acce_max_mins)\n",
    "        # Orientation detection:\n",
    "        headings = compute_headings(ahrs_datas)\n",
    "        step_headings = compute_step_heading(step_timestamps, headings)\n",
    "        rel_positions = compute_rel_positions(stride_lengths, step_headings)\n",
    "        # Running average, keeping every 3rd smoothed prediction:\n",
    "        posi_datas = uniform_filter1d(posi_datas,size=3,axis=0,mode='reflect')[0::3,:]\n",
    "        # The 1st prediction timepoint should be earlier than the 1st step timepoint.\n",
    "        rel_positions = rel_positions[rel_positions[:,0]>posi_datas[0,0],:]\n",
    "        # If two consecutive predictions are in-between two step datapoints,\n",
    "        # the last one is removed, causing error (in the \"split_ts_seq\" function).\n",
    "        # So predictions that share a step interval are averaged into one.\n",
    "        posi_index = np.searchsorted(rel_positions[:,0], posi_datas[:,0], side='right')\n",
    "        _, interval_ids = np.unique(posi_index, return_inverse=True)\n",
    "        posi_datas = np.vstack([np.mean(posi_datas[interval_ids==i],axis=0) for i in np.unique(interval_ids)])\n",
    "        # Position correction:\n",
    "        step_positions = correct_positions(rel_positions, posi_datas)\n",
    "        # Interpolate for timestamps in the testing set; outside the corrected time\n",
    "        # range x/y are clamped to the first/last corrected value\n",
    "        # (fill_value='extrapolate' would extrapolate instead).\n",
    "        t = step_positions[:,0]\n",
    "        x = step_positions[:,1]\n",
    "        y = step_positions[:,2]\n",
    "        fx = interp1d(t, x, kind='linear', fill_value=(x[0],x[-1]), bounds_error=False)\n",
    "        fy = interp1d(t, y, kind='linear', fill_value=(y[0],y[-1]), bounds_error=False)\n",
    "        # Output result. groupby keeps row order within a group, so t0 lines up with\n",
    "        # the rows selected by path_rows; the mask is built once per path.\n",
    "        t0 = np.array(samples.loc[(building,path_id),'timestamp'],dtype=np.float64)\n",
    "        path_rows = (sample_submission.building==building)&(sample_submission.path_id==path_id)\n",
    "        sample_submission.loc[path_rows,'x'] = fx(t0)\n",
    "        sample_submission.loc[path_rows,'y'] = fy(t0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "subold = pd.read_csv('submission_floor.csv')\n",
    "# Attach floors by submission key instead of positional index, so a floor file whose\n",
    "# rows are ordered differently (or are missing) cannot silently mislabel rows.\n",
    "# NOTE(review): assumes submission_floor.csv carries a 'site_path_timestamp' column.\n",
    "floor_by_key = subold.set_index('site_path_timestamp')['floor']\n",
    "sample_submission['floor'] = sample_submission['site_path_timestamp'].map(floor_by_key)\n",
    "assert sample_submission['floor'].notna().all(), 'submission_floor.csv has no floor for some rows'\n",
    "sample_submission[['site_path_timestamp','floor','x','y']].to_csv('submission_wifi.csv',index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
